Skip to content

Commit

Permalink
registration: make use of np.solve(), it is faster
Browse files Browse the repository at this point in the history
  • Loading branch information
rjw57 committed Mar 28, 2014
1 parent 62242c5 commit a720d22
Showing 1 changed file with 5 additions and 20 deletions.
25 changes: 5 additions & 20 deletions dtcwt/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,34 +226,19 @@ def solvetransform(Qtilde_vec):
# Want to find a = -Q^{-1} q => Qa = -q
# The naive way would be: return -np.linalg.inv(Q).dot(q)
# A less naive way would be: return np.linalg.solve(Q, -q)
#
# Recall that, given Q is symmetric, we can decompose it at Q = U L U^T with
# diagonal L and orthogonal U. Hence:
#
# U L U^T a = -q => a = - U L^-1 U^T q
#
# An even better way is thus to use the specialised eigenvector
# calculation for symmetric matrices:

# NumPy >1.8 directly supports using the last two dimensions as a matrix. IF
# we get a LinAlgError, we assume that we need to fall-back to a NumPy 1.7
# approach which is *significantly* slower.
try:
l, U = np.linalg.eigh(Q, 'U')
rv = np.linalg.solve(Q, -q)
except np.linalg.LinAlgError:
# Try the slower fallback
l = np.zeros(Qtilde_vec.shape[:-1] + (6,))
U = np.zeros(Qtilde_vec.shape[:-1] + (6,6))
rv = np.zeros(Qtilde_vec.shape[:-1] + (6,))
for idx in itertools.product(*list(xrange(s) for s in Qtilde_vec.shape[:-1])):
l[idx], U[idx] = np.linalg.eigh(Q[idx], 'U')
rv[idx] = np.linalg.solve(Q[idx], -q[idx])

# Now we have some issue here. If Qtilde_vec is a straightforward vector
# then we can just return U.dot((U.T.dot(-q))/l). However if Qtilde_vec is
# an array of vectors then the straightforward dot product won't work. We
# want to perform matrix multiplication on stacked matrices.
return dtcwt.utils.stacked_2d_matrix_vector_prod(
U, dtcwt.utils.stacked_2d_vector_matrix_prod(-q, U) / l
)
return rv

def normsamplehighpass(Yh, xs, ys, method=None):
"""
Expand Down Expand Up @@ -362,7 +347,7 @@ def estimatereg(source, reference, regshape=None):

qts = np.zeros(avecs.shape[:2] + all_qts[0].shape[2:])
for x in all_qts:
qts += dtcwt.sampling.rescale(_boxfilter(x, 3), avecs.shape[:2], method='nearest')
qts += dtcwt.sampling.rescale(_boxfilter(x, 3), avecs.shape[:2], method='bilinear')

avecs += solvetransform(qts)

Expand Down

0 comments on commit a720d22

Please sign in to comment.