Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

ENH: linalg/expm: merge expm() algorithm improvement

Merge branch 'rngantner/expm_rewrite'
  • Loading branch information...
commit 666b89a87d62a01699186e264c959aceda799a3e 2 parents 39294b7 + cbebb75
@pv pv authored
View
1  THANKS.txt
@@ -100,6 +100,7 @@ Denis Laxalde for the unified interface to minimizers in scipy.optimize.
David Fong for the sparse LSMR solver.
Andreas Hilboll for wrapping FITPACK's spgrid.f
Andrew Schein for improving the numerical precision of norm.logcdf().
+Robert Gantner for improving expm() implementation.
Institutions
View
134 scipy/linalg/matfuncs.py
@@ -8,7 +8,7 @@
from numpy import asarray, Inf, dot, floor, eye, diag, exp, \
product, logical_not, ravel, transpose, conjugate, \
cast, log, ogrid, imag, real, absolute, amax, sign, \
- isfinite, sqrt, identity, single
+ isfinite, sqrt, identity, single, ceil, log2
from numpy import matrix as mat
import numpy as np
@@ -19,56 +19,124 @@
from decomp import eig
from decomp_svd import orth, svd
from decomp_schur import schur, rsf2csf
+import warnings
eps = np.finfo(float).eps
feps = np.finfo(single).eps
-def expm(A, q=7):
+def expm(A, q=False):
"""Compute the matrix exponential using Pade approximation.
-
+
Parameters
----------
A : array, shape(M,M)
Matrix to be exponentiated
- q : integer
- Order of the Pade approximation
Returns
-------
expA : array, shape(M,M)
Matrix exponential of A
+ References
+ ----------
+ N. J. Higham,
+ "The Scaling and Squaring Method for the Matrix Exponential Revisited",
+ SIAM. J. Matrix Anal. & Appl. 26, 1179 (2005).
+
"""
+ if q: warnings.warn("argument q=... in scipy.linalg.expm is deprecated.")
A = asarray(A)
-
- # Scale A so that norm is < 1/2
- nA = norm(A,Inf)
- if nA==0:
- return identity(len(A), A.dtype.char)
- from numpy import log2
- val = log2(nA)
- e = int(floor(val))
- j = max(0,e+1)
- A = A / 2.0**j
-
- # Pade Approximation for exp(A)
- X = A
- c = 1.0/2
- N = eye(*A.shape) + c*A
- D = eye(*A.shape) - c*A
- for k in range(2,q+1):
- c = c * (q-k+1) / (k*(2*q-k+1))
- X = dot(A,X)
- cX = c*X
- N = N + cX
- if not k % 2:
- D = D + cX;
+ A_L1 = norm(A,1)
+ n_squarings = 0
+
+ if A.dtype == 'float64' or A.dtype == 'complex128':
+ if A_L1 < 1.495585217958292e-002:
+ U,V = _pade3(A)
+ elif A_L1 < 2.539398330063230e-001:
+ U,V = _pade5(A)
+ elif A_L1 < 9.504178996162932e-001:
+ U,V = _pade7(A)
+ elif A_L1 < 2.097847961257068e+000:
+ U,V = _pade9(A)
+ else:
+ maxnorm = 5.371920351148152
+ n_squarings = max(0, int(ceil(log2(A_L1 / maxnorm))))
+ A = A / 2**n_squarings
+ U,V = _pade13(A)
+ elif A.dtype == 'float32' or A.dtype == 'complex64':
+ if A_L1 < 4.258730016922831e-001:
+ U,V = _pade3(A)
+ elif A_L1 < 1.880152677804762e+000:
+ U,V = _pade5(A)
else:
- D = D - cX;
- F = solve(D,N)
- for k in range(1,j+1):
- F = dot(F,F)
- return F
+ maxnorm = 3.925724783138660
+ n_squarings = max(0, int(ceil(log2(A_L1 / maxnorm))))
+ A = A / 2**n_squarings
+ U,V = _pade7(A)
+ else:
+ raise ValueError("invalid type: "+str(A.dtype))
+
+ P = U + V # p_m(A) : numerator
+ Q = -U + V # q_m(A) : denominator
+ R = solve(Q,P)
+ # squaring step to undo scaling
+ for i in range(n_squarings):
+ R = dot(R,R)
+
+ return R
+
+# implementation of Pade approximations of various degree using the algorithm presented in [Higham 2005]
+
+def _pade3(A):
+ b = (120., 60., 12., 1.)
+ ident = eye(*A.shape, dtype=A.dtype)
+ A2 = dot(A,A)
+ U = dot(A , (b[3]*A2 + b[1]*ident))
+ V = b[2]*A2 + b[0]*ident
+ return U,V
+
+def _pade5(A):
+ b = (30240., 15120., 3360., 420., 30., 1.)
+ ident = eye(*A.shape, dtype=A.dtype)
+ A2 = dot(A,A)
+ A4 = dot(A2,A2)
+ U = dot(A, b[5]*A4 + b[3]*A2 + b[1]*ident)
+ V = b[4]*A4 + b[2]*A2 + b[0]*ident
+ return U,V
+
+def _pade7(A):
+ b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
+ ident = eye(*A.shape, dtype=A.dtype)
+ A2 = dot(A,A)
+ A4 = dot(A2,A2)
+ A6 = dot(A4,A2)
+ U = dot(A, b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
+ V = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
+ return U,V
+
+def _pade9(A):
+ b = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
+ 2162160., 110880., 3960., 90., 1.)
+ ident = eye(*A.shape, dtype=A.dtype)
+ A2 = dot(A,A)
+ A4 = dot(A2,A2)
+ A6 = dot(A4,A2)
+ A8 = dot(A6,A2)
+ U = dot(A, b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
+ V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
+ return U,V
+
+def _pade13(A):
+ b = (64764752532480000., 32382376266240000., 7771770303897600.,
+ 1187353796428800., 129060195264000., 10559470521600., 670442572800.,
+ 33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.)
+ ident = eye(*A.shape, dtype=A.dtype)
+ A2 = dot(A,A)
+ A4 = dot(A2,A2)
+ A6 = dot(A4,A2)
+ U = dot(A,dot(A6, b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
+ V = dot(A6, b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
+ return U,V
def expm2(A):
"""Compute the matrix exponential using eigenvalue decomposition.
View
25 scipy/linalg/tests/test_matfuncs.py
@@ -6,8 +6,10 @@
"""
-from numpy import array, identity, dot, sqrt
-from numpy.testing import TestCase, run_module_suite, assert_array_almost_equal
+import numpy as np
+from numpy import array, identity, dot, sqrt, double, exp, random
+from numpy.testing import TestCase, run_module_suite, assert_array_almost_equal, \
+ assert_array_almost_equal_nulp
from scipy.linalg import signm, logm, sqrtm, expm, expm2, expm3
@@ -102,5 +104,24 @@ def test_consistency(self):
assert_array_almost_equal(expm(a), expm2(a))
assert_array_almost_equal(expm(a), expm3(a))
+ def test_padecases_dtype(self):
+ for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
+ # test double-precision cases
+ for scale in [1e-2, 1e-1, 5e-1, 1, 10]:
+ a = scale * identity(3, dtype=dtype)
+ e = exp(scale) * identity(3, dtype=dtype)
+ assert_array_almost_equal_nulp(expm(a), e, nulp=100)
+
+ def test_logm_consistency(self):
+ random.seed(1234)
+ for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
+ for n in range(1, 10):
+ for scale in [1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2]:
+ # make logm(a) be of a given scale
+ a = (identity(n) + random.rand(n, n) * scale).astype(dtype)
+ if np.iscomplexobj(a):
+ a = a + 1j * random.rand(n, n) * scale
+ assert_array_almost_equal(expm(logm(a)), a)
+
if __name__ == "__main__":
run_module_suite()
Please sign in to comment.
Something went wrong with that request. Please try again.