Skip to content

Commit

Permalink
Implement groupwise hc2 with covariance.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexlee-gk committed Feb 17, 2015
1 parent 42f5b0e commit fde606f
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions lfd/registration/hc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import settings
import numpy as np
import scipy.optimize as so
import tps_experimental
from tps_experimental import ThinPlateSpline, multi_tps_to_params, params_to_multi_tps

def analymin(x, y, beta):
Expand Down Expand Up @@ -198,3 +199,51 @@ def opt_callback(z_knd):
for f in f_k:
f.trans_g -= trans_d
return f_k

def groupwise_tps_hc2_cov_obj(z_knd, f_k, p_ktd, reg, rot_reg, cov_coef, y_md=None, L_ktkn=None):
f_k = params_to_multi_tps(z_knd, f_k)

gw_tps_hc2_energy, gw_tps_hc2_grad_knd = groupwise_tps_hc2_obj(z_knd, f_k, reg, rot_reg, y_md=y_md)
cov_energy, cov_grad_knd = tps_experimental.tps_cov_obj(z_knd, f_k, p_ktd, L_ktkn=L_ktkn)
energy = gw_tps_hc2_energy + cov_coef * cov_energy
grad_knd = gw_tps_hc2_grad_knd + cov_coef * cov_grad_knd

return energy, grad_knd

def groupwise_tps_hc2_cov(x_kld, p_ktd, y_md=None, ctrl_knd=None, f_init_k=None, opt_iter=100, reg=settings.REG[1], rot_reg=settings.ROT_REG, cov_coef=settings.COV_COEF, callback=None, multi_callback=None):
# intitalize z from independent optimizations from the one without covariance
f_k = groupwise_tps_hc2(x_kld, y_md=y_md, ctrl_knd=ctrl_knd, f_init_k=f_init_k, opt_iter=opt_iter, reg=reg, rot_reg=rot_reg, callback=callback)

# translate problem by trans_d. At the end, problem needs to be translated back
if y_md is not None:
# shift y_md and f by trans_d so that y_md is in R plus for stability
_, d = y_md.shape
[y_trans_md], trans_d = translate_to_R_plus([y_md.copy()], np.ones(d), ret_translation=True) # copy() because it is translated in place
else:
d = x_kld[0].shape[1]
trans_d = np.zeros(d)
y_trans_md = y_md
for f in f_k:
f.trans_g += trans_d

z_knd = multi_tps_to_params(f_k)

# put together matrix for computing sum of variances
L_ktkn = tps_experimental.compute_sum_var_matrix(f_k, p_ktd)

def opt_callback(z_knd):
params_to_multi_tps(z_knd, f_k)
for f in f_k:
f.trans_g -= trans_d
multi_callback(f_k, p_ktd, y_md)
for f in f_k:
f.trans_g += trans_d
# print groupwise_tps_hc2_cov_obj(z_knd, f_k, p_ktd, reg, rot_reg, cov_coef, y_md=y_trans_md, L_ktkn=L_ktkn)[0]

res = so.fmin_l_bfgs_b(groupwise_tps_hc2_cov_obj, z_knd, None, args=(f_k, p_ktd, reg, rot_reg, cov_coef, y_trans_md, L_ktkn), maxfun=opt_iter, callback=opt_callback if multi_callback is not None else None)
z_knd = res[0]

f_k = params_to_multi_tps(z_knd, f_k)
for f in f_k:
f.trans_g -= trans_d
return f_k

0 comments on commit fde606f

Please sign in to comment.