Skip to content

Commit

Permalink
Implement biased groupwise registration in tps_hc2.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexlee-gk committed Feb 17, 2015
1 parent e73b702 commit 6040e36
Showing 1 changed file with 41 additions and 11 deletions.
52 changes: 41 additions & 11 deletions lfd/registration/hc.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,18 @@ def hc2_obj(x_kld):
grad_kld.append(grad_ld)
return energy, grad_kld

def translate_to_R_plus(x_kld, region):
def translate_to_R_plus(x_kld, region, ret_translation=False):
min_x_kld = np.min([np.min(x_ld) for x_ld in x_kld])
translation = -(min_x_kld - region)
for x_ld in x_kld:
x_ld -= min_x_kld - region
return x_kld
x_ld += translation
if ret_translation:
ret = x_kld, translation
else:
ret = x_kld
return ret

def tps_hc2_obj(z_knd, f_k, reg, rot_reg):
def tps_hc2_obj(z_knd, f_k, reg, rot_reg, y_md=None):
f_k = params_to_multi_tps(z_knd, f_k)

xwarped_kld = []
Expand All @@ -127,7 +132,10 @@ def tps_hc2_obj(z_knd, f_k, reg, rot_reg):
_, d = xwarped_kld[0].shape
xwarped_kld = translate_to_R_plus(xwarped_kld, np.ones(d))

hc2_energy, hc2_grad_kld = hc2_obj(xwarped_kld)
if y_md is None:
hc2_energy, hc2_grad_kld = hc2_obj(xwarped_kld)
else:
hc2_energy, hc2_grad_kld = hc2_obj(xwarped_kld + [y_md])
energy = hc2_energy
grad_knd = []
for f, hc2_grad_ld in zip(f_k, hc2_grad_kld):
Expand All @@ -141,10 +149,13 @@ def tps_hc2_obj(z_knd, f_k, reg, rot_reg):
grad_knd = np.concatenate(grad_knd)
return energy, grad_knd

def tps_hc2(x_kld, ctrl_knd=None, f_init_k=None, opt_iter=100, reg=settings.REG[1], rot_reg=settings.ROT_REG, callback=None):
def tps_hc2(x_kld, y_md=None, ctrl_knd=None, f_init_k=None, opt_iter=100, reg=settings.REG[1], rot_reg=settings.ROT_REG, callback=None):
if f_init_k is None:
if ctrl_knd is None:
ctrl_knd = x_kld
else:
if len(ctrl_knd) != len(x_kld):
raise ValueError("The number of control points in ctrl_knd is different from the number of point sets in x_kld")
f_k = []
for x_ld, ctrl_nd in zip(x_kld, ctrl_knd):
f = ThinPlateSpline(x_ld, ctrl_nd)
Expand All @@ -153,15 +164,34 @@ def tps_hc2(x_kld, ctrl_knd=None, f_init_k=None, opt_iter=100, reg=settings.REG[
if len(f_init_k) != len(x_kld):
raise ValueError("The number of ThinPlateSplines in f_init_k is different from the number of point sets in x_kld")
f_k = f_init_k

# translate problem by trans_d. At the end, problem needs to be translated back
if y_md is not None:
# shift y_md and f by trans_d so that y_md is in R plus for stability
_, d = y_md.shape
[y_trans_md], trans_d = translate_to_R_plus([y_md.copy()], np.ones(d), ret_translation=True) # copy() because it is translated in place
else:
d = x_kld[0].shape[1]
trans_d = np.zeros(d)
y_trans_md = y_md
for f in f_k:
f.trans_g += trans_d

z_knd = multi_tps_to_params(f_k)

def opt_callback(z_knd):
callback(params_to_multi_tps(z_knd, f_k))
# print tps_hc2_obj(z_knd, f_k, reg, rot_reg)[0]

res = so.fmin_l_bfgs_b(tps_hc2_obj, z_knd, None, args=(f_k, reg, rot_reg), maxfun=opt_iter, callback=opt_callback if callback is not None else None)
params_to_multi_tps(z_knd, f_k)
for f in f_k:
f.trans_g -= trans_d
callback(f_k, y_md)
for f in f_k:
f.trans_g += trans_d
print tps_hc2_obj(z_knd, f_k, reg, rot_reg, y_md=y_trans_md)[0]

res = so.fmin_l_bfgs_b(tps_hc2_obj, z_knd, None, args=(f_k, reg, rot_reg, y_trans_md), maxfun=opt_iter, callback=opt_callback if callback is not None else None)
z_knd = res[0]

f_k = params_to_multi_tps(z_knd, f_k)

for f in f_k:
f.trans_g -= trans_d
return f_k

0 comments on commit 6040e36

Please sign in to comment.