Skip to content

Commit

Permalink
Refactored _padeXX() functs to use sparse or dense matrices.
Browse files Browse the repository at this point in the history
  • Loading branch information
scopatz committed Aug 14, 2012
1 parent c929e84 commit db58855
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 87 deletions.
80 changes: 55 additions & 25 deletions scipy/linalg/matfuncs.py
Expand Up @@ -12,6 +12,9 @@
from numpy import matrix as mat
import numpy as np

from scipy.sparse.base import isspmatrix
from scipy.sparse.construct import eye as speye

# Local imports
from misc import norm
from basic import solve, inv
Expand Down Expand Up @@ -87,56 +90,83 @@ def expm(A, q=False):
return R

# implementation of Pade approximations of various degree using the algorithm presented in [Higham 2005]

# These should apply to both dense and sparse matricies.
def _pade3(A):
b = (120., 60., 12., 1.)
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype)
A2 = dot(A,A)
U = dot(A , (b[3]*A2 + b[1]*ident))
if isspmatrix(A):
ident = speye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format)
else:
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype)
A2 = A.dot(A)
U = A.dot(b[3]*A2 + b[1]*ident)
V = b[2]*A2 + b[0]*ident
return U,V

def _pade5(A):
b = (30240., 15120., 3360., 420., 30., 1.)
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype)
A2 = dot(A,A)
A4 = dot(A2,A2)
U = dot(A, b[5]*A4 + b[3]*A2 + b[1]*ident)
if isspmatrix(A):
ident = speye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format)
else:
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype)
A2 = A.dot(A)
A4 = A2.dot(A2)
U = A.dot(b[5]*A4 + b[3]*A2 + b[1]*ident)
V = b[4]*A4 + b[2]*A2 + b[0]*ident
return U,V

def _pade5(A):
b = (30240., 15120., 3360., 420., 30., 1.)
if isspmatrix(A):
ident = speye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format)
else:
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype)
A2 = A.dot(A)
A4 = A2.dot(A2)
U = A.dot(b[5]*A4 + b[3]*A2 + b[1]*ident)
V = b[4]*A4 + b[2]*A2 + b[0]*ident
return U,V

def _pade7(A):
b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype)
A2 = dot(A,A)
A4 = dot(A2,A2)
A6 = dot(A4,A2)
U = dot(A, b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
if isspmatrix(A):
ident = speye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format)
else:
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype)
A2 = A.dot(A)
A4 = A2.dot(A2)
A6 = A4.dot(A2)
U = A.dot(b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
V = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
return U,V

def _pade9(A):
b = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
2162160., 110880., 3960., 90., 1.)
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype)
A2 = dot(A,A)
A4 = dot(A2,A2)
A6 = dot(A4,A2)
A8 = dot(A6,A2)
U = dot(A, b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
if isspmatrix(A):
ident = speye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format)
else:
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype)
A2 = A.dot(A)
A4 = A2.dot(A2)
A6 = A4.dot(A2)
A8 = A6.dot(A2)
U = A.dot(b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
return U,V

def _pade13(A):
b = (64764752532480000., 32382376266240000., 7771770303897600.,
1187353796428800., 129060195264000., 10559470521600., 670442572800.,
33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.)
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype)
A2 = dot(A,A)
A4 = dot(A2,A2)
A6 = dot(A4,A2)
U = dot(A,dot(A6, b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
V = dot(A6, b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
if isspmatrix(A):
ident = speye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format)
else:
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype)
A2 = A.dot(A)
A4 = A2.dot(A2)
A6 = A4.dot(A2)
U = A.dot(A6.dot(b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
V = A6.dot(b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
return U,V

def expm2(A):
Expand Down
5 changes: 5 additions & 0 deletions scipy/linalg/misc.py
@@ -1,7 +1,12 @@
import numpy as np
from numpy import eye
from numpy.linalg import LinAlgError
import fblas

from scipy.sparse.base import isspmatrix
from scipy.sparse.construct import eye as speye


__all__ = ['LinAlgError', 'norm']

_nrm2_prefix = {'f' : 's', 'F': 'sc', 'D': 'dz'}
Expand Down
63 changes: 2 additions & 61 deletions scipy/sparse/base.py
Expand Up @@ -639,6 +639,8 @@ def expm(self):
--------
scipy.linalg.expm
"""
from scipy.linalg.matfuncs import _pade3, _pade5, _pade7, _pade9, _pade13

A_L1 = max(abs(self).sum(axis=0).flat)
n_squarings = 0

Expand Down Expand Up @@ -691,67 +693,6 @@ def inv(self):
return selfinv



def _pade3(A):
from construct import eye
b = (120., 60., 12., 1.)
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format)
A2 = A.dot(A)
U = A.dot(b[3]*A2 + b[1]*ident)
V = b[2]*A2 + b[0]*ident
return U,V

def _pade5(A):
from construct import eye
b = (30240., 15120., 3360., 420., 30., 1.)
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format)
A2 = A.dot(A)
A4 = A2.dot(A2)
U = A.dot(b[5]*A4 + b[3]*A2 + b[1]*ident)
V = b[4]*A4 + b[2]*A2 + b[0]*ident
return U,V

def _pade7(A):
from construct import eye
b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format)
A2 = A.dot(A)
A4 = A2.dot(A2)
A6 = A4.dot(A2)
U = A.dot(b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
V = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
return U,V

def _pade9(A):
from construct import eye
b = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
2162160., 110880., 3960., 90., 1.)
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format)
A2 = A.dot(A)
A4 = A2.dot(A2)
A6 = A4.dot(A2)
A8 = A6.dot(A2)
U = A.dot(b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
return U,V

def _pade13(A):
from construct import eye
b = (64764752532480000., 32382376266240000., 7771770303897600.,
1187353796428800., 129060195264000., 10559470521600., 670442572800.,
33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.)
ident = eye(A.shape[0], A.shape[1], dtype=A.dtype, format=A.format)
A2 = A.dot(A)
A4 = A2.dot(A2)
A6 = A4.dot(A2)
U = A.dot(A6.dot(b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
V = A6.dot(b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
return U,V





def isspmatrix(x):
return isinstance(x, spmatrix)

Expand Down
2 changes: 1 addition & 1 deletion scipy/sparse/linalg/dsolve/linsolve.py
Expand Up @@ -62,7 +62,7 @@ def spsolve(A, b, permc_spec=None, use_umfpack=True):
msg += "or matrix whose shape matches lhs (%s)" % (A.shape,)
raise ValueError(msg)
elif isspmatrix(b) and not (isspmatrix_csc(b) or isspmatrix_csr(b)):
B = csc_matrix(b)
b = csc_matrix(b)
warn('solve requires b be CSC or CSR matrix format', SparseEfficiencyWarning)

A.sort_indices()
Expand Down

0 comments on commit db58855

Please sign in to comment.