Permalink
Browse files

Merge pull request #1412 from mrocklin/transpose-decentralize

Distributed transpose logic
  • Loading branch information...
2 parents bff3595 + f802721 commit 96259182f9fcc3a63609d401e412068e779618dc @mrocklin mrocklin committed Aug 2, 2012
@@ -85,7 +85,7 @@ def _blockadd(self, other):
return MatrixExpr.__add__(self, other)
- def eval_transpose(self):
+ def _eval_transpose(self):
# Flip all the individual matrices
matrices = [Transpose(matrix) for matrix in self.mat.mat]
# Make a copy
@@ -97,7 +97,7 @@ def eval_transpose(self):
#def transpose(self):
# return self.eval_transpose()
- def eval_inverse(self, expand=False):
+ def _eval_inverse(self, expand=False):
# Inverse of one by one block matrix is easy
if self.blockshape==(1,1):
mat = Matrix(1, 1, (Inverse(self.blocks[0]), ))
@@ -184,7 +184,7 @@ def __new__(cls, *mats):
def diag(self):
return self.args[2]
- def eval_inverse(self):
+ def _eval_inverse(self):
return BlockDiagMatrix(*[Inverse(mat) for mat in self.diag])
def _blockmul(self, other):
@@ -26,30 +26,13 @@ def __new__(cls, mat, **kwargs):
if not mat.is_Matrix:
return mat**(-1)
- try:
- return mat.eval_inverse(**kwargs)
- except (AttributeError, NotImplementedError):
- pass
-
- if hasattr(mat, 'inv'):
- return mat.inv()
-
- if mat.is_Inverse:
- return mat.arg
-
- if mat.is_Identity:
- return mat
-
if not mat.is_square:
raise ShapeError("Inverse of non-square matrix %s"%mat)
- if mat.is_Mul:
- try:
- return MatMul(*[Inverse(arg) for arg in mat.args[::-1]])
- except ShapeError:
- pass
-
- return MatPow.__new__(cls, mat, -1)
+ try:
+ return mat._eval_inverse(**kwargs)
+ except (AttributeError, NotImplementedError):
+ return MatPow.__new__(cls, mat, -1)
@property
def arg(self):
@@ -59,4 +42,5 @@ def arg(self):
def shape(self):
return self.arg.shape
-from matmul import MatMul
+ def _eval_inverse(self):
+ return self.arg
@@ -55,4 +55,8 @@ def shape(self):
def _entry(self, i, j):
return Add(*[arg._entry(i,j) for arg in self.args])
+ def _eval_transpose(self):
+ from transpose import Transpose
+ return MatAdd(*[Transpose(arg) for arg in self.args])
+
from matmul import MatMul
@@ -100,10 +100,10 @@ def cols(self):
def is_square(self):
return self.rows == self.cols
- def eval_transpose(self):
+ def _eval_transpose(self):
raise NotImplementedError()
- def eval_inverse(self):
+ def _eval_inverse(self):
raise NotImplementedError()
@property
@@ -260,7 +260,10 @@ class Identity(MatrixSymbol):
def __new__(cls, n):
return MatrixSymbol.__new__(cls, "I", n, n)
- def transpose(self):
+ def _eval_transpose(self):
+ return self
+
+ def _eval_inverse(self):
return self
def _entry(self, i, j):
@@ -282,7 +285,8 @@ class ZeroMatrix(MatrixSymbol):
is_ZeroMatrix = True
def __new__(cls, n, m):
return MatrixSymbol.__new__(cls, "0", n, m)
- def transpose(self):
+
+ def _eval_transpose(self):
return ZeroMatrix(self.cols, self.rows)
def _entry(self, i, j):
@@ -84,5 +84,16 @@ def as_coeff_mmul(self):
return coeff, MatMul(*matrices)
+ def _eval_transpose(self):
+ from transpose import Transpose
+ return MatMul(*[Transpose(arg) for arg in self.args[::-1]])
+
+ def _eval_inverse(self):
+ from inverse import Inverse
+ try:
+ return MatMul(*[Inverse(arg) for arg in self.args[::-1]])
+ except ShapeError:
+ raise NotImplementedError("Can not decompose this Inverse")
+
+
from matadd import MatAdd
-from inverse import Inverse
@@ -24,19 +24,10 @@ def __new__(cls, mat):
if not mat.is_Matrix:
return mat
- if isinstance(mat, Transpose):
- return mat.arg
-
- if hasattr(mat, 'transpose'):
- return mat.transpose()
-
- if mat.is_Mul:
- return MatMul(*[Transpose(arg) for arg in mat.args[::-1]])
-
- if mat.is_Add:
- return MatAdd(*[Transpose(arg) for arg in mat.args])
-
- return Basic.__new__(cls, mat)
+ try:
+ return mat._eval_transpose()
+ except (AttributeError, NotImplementedError):
+ return Basic.__new__(cls, mat)
@property
def arg(self):
@@ -49,5 +40,5 @@ def shape(self):
def _entry(self, i, j):
return self.arg._entry(j, i)
-from matmul import MatMul
-from matadd import MatAdd
+ def _eval_transpose(self):
+ return self.arg
@@ -38,3 +38,4 @@ def __setitem__(self, *args):
equals = MatrixBase.equals
is_Identity = MatrixBase.is_Identity
+ _eval_transpose = MatrixBase._eval_transpose
@@ -174,6 +174,9 @@ def _handle_creation_inputs(cls, *args, **kwargs):
return rows, cols, mat
+ def _eval_transpose(self):
+ return self.transpose()
+
def transpose(self):
"""
Matrix transposition.
@@ -816,6 +819,8 @@ def inv(self, method="GE", iszerofunc=_iszero, try_block_diag=False):
# if a new method is added.
raise ValueError("Inversion method unrecognized")
+ def _eval_inverse(self):
+ return self.inv()
def __mathml__(self):
mml = ""

0 comments on commit 9625918

Please sign in to comment.