Skip to content

Commit

Permalink
Fix matrix power of identity/zero matrices
Browse files Browse the repository at this point in the history
Closes sympy/sympy#9823

Signed-off-by: Sergey B Kirpichev <skirpichev@gmail.com>
  • Loading branch information
kevalds51 authored and skirpichev committed Jul 17, 2016
1 parent 13be39a commit ddefdd9
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 6 deletions.
6 changes: 5 additions & 1 deletion sympy/matrices/expressions/matexpr.py
Expand Up @@ -106,7 +106,9 @@ def __rmul__(self, other):
def __pow__(self, other):
if not self.is_square:
raise ShapeError("Power of non-square matrix %s" % self)
if other is S.NegativeOne:
elif self.is_Identity:
return self
elif other is S.NegativeOne:
return Inverse(self)
elif other is S.Zero:
return Identity(self.rows)
Expand Down Expand Up @@ -462,6 +464,8 @@ def __pow__(self, other):
raise ShapeError("Power of non-square matrix %s" % self)
if other == 0:
return Identity(self.rows)
if other < 1:
raise ValueError("Matrix det == 0; not invertible.")
return self

def _eval_transpose(self):
Expand Down
14 changes: 10 additions & 4 deletions sympy/matrices/expressions/matpow.py
@@ -1,4 +1,4 @@
from .matexpr import MatrixExpr, ShapeError, Identity
from .matexpr import MatrixExpr, ShapeError, Identity, ZeroMatrix
from sympy.core.sympify import _sympify
from sympy.matrices import MatrixBase
from sympy.core import S
Expand Down Expand Up @@ -52,14 +52,20 @@ def doit(self, **kwargs):
args = self.args
base = args[0]
exp = args[1]
if isinstance(base, MatrixBase) and exp.is_number:
if exp.is_zero and base.is_square:
if isinstance(base, MatrixBase):
return base.func(Identity(base.shape[0]))
return Identity(base.shape[0])
elif isinstance(base, ZeroMatrix) and exp.is_negative:
raise ValueError("Matrix det == 0; not invertible.")
elif isinstance(base, (Identity, ZeroMatrix)):
return base
elif isinstance(base, MatrixBase) and exp.is_number:
if exp is S.One:
return base
return base**exp
# Note: just evaluate cases we know, return unevaluated on others.
# E.g., MatrixSymbol('x', n, m) to power 0 is not an error.
if exp.is_zero and base.is_square:
return Identity(base.shape[0])
elif exp is S.One:
return base
return MatPow(base, exp)
Expand Down
29 changes: 28 additions & 1 deletion sympy/matrices/expressions/tests/test_matpow.py
@@ -1,7 +1,7 @@
import pytest

from sympy.core import symbols, pi, S
from sympy.matrices import Identity, MatrixSymbol, ImmutableMatrix
from sympy.matrices import Identity, MatrixSymbol, ImmutableMatrix, ZeroMatrix
from sympy.matrices.expressions import MatPow, MatAdd, MatMul
from sympy.matrices.expressions.matexpr import ShapeError

Expand Down Expand Up @@ -94,3 +94,30 @@ def test_doit_nested_MatrixExpr():
Y = ImmutableMatrix([[2, 3], [4, 5]])
assert MatPow(MatMul(X, Y), 2).doit() == (X*Y)**2
assert MatPow(MatAdd(X, Y), 2).doit() == (X + Y)**2


def test_identity_power():
k = Identity(n)
assert MatPow(k, 4).doit() == k
assert MatPow(k, n).doit() == k
assert MatPow(k, -3).doit() == k
assert MatPow(k, 0).doit() == k
l = Identity(3)
assert MatPow(l, n).doit() == l
assert MatPow(l, -1).doit() == l
assert MatPow(l, 0).doit() == l


def test_zero_power():
z1 = ZeroMatrix(n, n)
assert MatPow(z1, 3).doit() == z1
pytest.raises(ValueError, lambda: MatPow(z1, -1).doit())
assert MatPow(z1, 0).doit() == Identity(n)
assert MatPow(z1, n).doit() == z1
pytest.raises(ValueError, lambda: MatPow(z1, -2).doit())
z2 = ZeroMatrix(4, 4)
assert MatPow(z2, n).doit() == z2
pytest.raises(ValueError, lambda: MatPow(z2, -3).doit())
assert MatPow(z2, 2).doit() == z2
assert MatPow(z2, 0).doit() == Identity(4)
pytest.raises(ValueError, lambda: MatPow(z2, -1).doit())
29 changes: 29 additions & 0 deletions sympy/matrices/expressions/tests/test_matrix_exprs.py
Expand Up @@ -218,3 +218,32 @@ def test_MatrixElement_doit():
u = MatrixSymbol('u', 2, 1)
v = ImmutableMatrix([3, 5])
assert u[0, 0].subs(u, v).doit() == v[0, 0]


def test_identity_powers():
M = Identity(n)
assert MatPow(M, 3).doit() == M**3
assert M**n == M
assert MatPow(M, 0).doit() == M**2
assert M**-2 == M
assert MatPow(M, -2).doit() == M**0
N = Identity(3)
assert MatPow(N, 2).doit() == N**n
assert MatPow(N, 3).doit() == N
assert MatPow(N, -2).doit() == N**4
assert MatPow(N, 2).doit() == N**0


def test_Zero_power():
z1 = ZeroMatrix(n, n)
assert z1**4 == z1
pytest.raises(ValueError, lambda: z1**-2)
assert z1**0 == Identity(n)
assert MatPow(z1, 2).doit() == z1**2
pytest.raises(ValueError, lambda: MatPow(z1, -2).doit())
z2 = ZeroMatrix(3, 3)
assert MatPow(z2, 4).doit() == z2**4
pytest.raises(ValueError, lambda: z2**-3)
assert z2**3 == MatPow(z2, 3).doit()
assert z2**0 == Identity(3)
pytest.raises(ValueError, lambda: MatPow(z2, -1).doit())

0 comments on commit ddefdd9

Please sign in to comment.