Skip to content

Commit

Permalink
registration: add customisation of level selection
Browse files Browse the repository at this point in the history
  • Loading branch information
rjw57 committed Oct 16, 2014
1 parent 018ace2 commit 722961f
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions dtcwt/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def warptransform(t, avecs, levels, method=None):
# Clone the transform
return dtcwt.numpy.Pyramid(t.lowpass, tuple(warped_highpasses), t.scales)

def estimatereg(source, reference, regshape=None):
def estimatereg(source, reference, regshape=None, levels=None):
"""
Estimate registration from which will map *source* to *reference*.
Expand All @@ -318,6 +318,10 @@ def estimatereg(source, reference, regshape=None):
Use the :py:func:`velocityfield` function to convert the return value from
this function into a velocity field.
If not-`None`, *levels* is a sequence of sequences of 0-based level indices
to use when calculating the registration. If `None` then a default set of
levels are used.
"""
# Extract number of levels and shape of level 4 (i.e. index 3) subband
nlevels = len(source.highpasses)
Expand All @@ -329,11 +333,19 @@ def estimatereg(source, reference, regshape=None):
# Initialise matrix of 'a' vectors
avecs = np.zeros(avecs_shape)

if levels is None:
levels = []
levels.append(list(x for x in xrange(nlevels-1, nlevels-3, -1) if x>=0))
for s in np.arange(nlevels-1, 0, -0.5):
refine_levels = list(int(np.floor(s))-x for x in range(2) if s-x >= 2)
if len(refine_levels) < 2:
continue
levels.append(refine_levels)

# Compute initial global transform
levels = list(x for x in xrange(nlevels-1, nlevels-3, -1) if x>=0)
Qt_mats = list(
np.sum(np.sum(x, axis=0), axis=0)
for x in qtildematrices(source, reference, levels)
for x in qtildematrices(source, reference, levels[0])
)
Qt = np.sum(Qt_mats, axis=0)

Expand All @@ -342,16 +354,12 @@ def estimatereg(source, reference, regshape=None):
avecs[:,:,idx] = a[idx]

# Refine estimate
for s in np.arange(nlevels-1, 0, -0.5):
levels = list(int(np.floor(s))-x for x in range(2) if s-x >= 2)
if len(levels) < 2:
continue

for est_levels in levels[1:]:
# Warp the levels we'll be looking at with the current best-guess transform
warped = warptransform(source, avecs, levels, method='bilinear')
warped = warptransform(source, avecs, est_levels, method='bilinear')

# Rescale and sample all the Qtilde matrix results
all_qts = qtildematrices(warped, reference, levels)
all_qts = qtildematrices(warped, reference, est_levels)
if all_qts is None or len(all_qts) < 1:
continue

Expand Down

0 comments on commit 722961f

Please sign in to comment.