Skip to content

Commit

Permalink
Abstract covariance objective and fix its 1/k scaling.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexlee-gk committed Feb 17, 2015
1 parent 935d1a4 commit cd754f9
Showing 1 changed file with 55 additions and 26 deletions.
81 changes: 55 additions & 26 deletions lfd/registration/tps_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,16 +248,63 @@ def pairwise_tps_l2_obj(z_knd, f_k, y_md, rad, reg, rot_reg):
grad_knd = np.concatenate(grad_knd)
return energy, grad_knd

def multi_tps_l2_obj(z_knd, f_k, L_ktkn, y_md, p_ktd, rad, reg, rot_reg, cov_coef):
def compute_sum_var_matrix(f_k, p_ktd):
"""Computes the kt by kn matrix L_ktkn for calculating the sum of variances.
The sum of variances is given by
(1/k) * np.sum(np.square(L_ktkn.dot(z_knd.reshape((-1,d)))))
"""
QN_ktn = []
for f, p_td in zip(f_k, p_ktd):
QN_tn = f.compute_transform_grad(p_td)
QN_ktn.append(QN_tn)

k, t, _ = p_ktd.shape
L_ktkn = []
for j in range(t):
QN_1kn = []
for QN_tn in QN_ktn:
QN_1kn.append(QN_tn[j,:])
QN_1kn = np.concatenate(QN_1kn)
i = 0
for QN_tn in QN_ktn:
_, n = QN_tn.shape
L_1kn = (-1/k) * QN_1kn
L_1kn[i:i+n] += QN_tn[j,:]
L_ktkn.append(L_1kn)
i += n
L_ktkn = np.array(L_ktkn)
return L_ktkn

def tps_cov_obj(z_knd, f_k, p_ktd, L_ktkn=None):
f_k = params_to_multi_tps(z_knd, f_k)

pw_tps_l2_energy, pw_tps_l2_grad_knd = pairwise_tps_l2_obj(z_knd, f_k, y_md, rad, reg, rot_reg)
if L_ktkn is None:
L_ktkn = compute_sum_var_matrix(f_k, p_ktd)

_, d = y_md.shape
k, t, d = p_ktd.shape
Lz_ktd = L_ktkn.dot(z_knd.reshape((-1,d)))
cov_energy = np.sum(np.square(Lz_ktd))
cov_grad_knd = 2 * L_ktkn.T.dot(Lz_ktd).reshape(-1)
energy = (1/k) * np.sum(np.square(Lz_ktd))
grad_knd = (1/k) * 2 * L_ktkn.T.dot(Lz_ktd).reshape(-1)

# fp_ktd = []
# for f, p_td in zip(f_k, p_ktd):
# fp_td = f.transform_points(p_td)
# fp_ktd.append(fp_td)
# fp_ktd = np.array(fp_ktd)
# energy2 = 0
# for j in range(t):
# fp_kd = fp_ktd[:,j,:]
# energy2 += (1/k) * np.trace((fp_kd - fp_kd.mean(axis=0)).T.dot(fp_kd - fp_kd.mean(axis=0)))
# print "energy cov equal?", np.allclose(energy, energy2)

return energy, grad_knd

def pairwise_tps_l2_cov_obj(z_knd, f_k, y_md, p_ktd, rad, reg, rot_reg, cov_coef, L_ktkn=None):
f_k = params_to_multi_tps(z_knd, f_k)

pw_tps_l2_energy, pw_tps_l2_grad_knd = pairwise_tps_l2_obj(z_knd, f_k, y_md, rad, reg, rot_reg)
cov_energy, cov_grad_knd = tps_cov_obj(z_knd, f_k, p_ktd, L_ktkn=L_ktkn)
energy = pw_tps_l2_energy + cov_coef * cov_energy
grad_knd = pw_tps_l2_grad_knd + cov_coef * cov_grad_knd

Expand All @@ -279,7 +326,7 @@ def params_to_multi_tps(z_knd, f_k):
i += n
return f_k

def multi_tps_l2(x_kld, y_md, p_ktd, ctrl_knd=None,
def pairwise_tps_l2_cov(x_kld, y_md, p_ktd, ctrl_knd=None,
n_iter=settings.N_ITER, opt_iter=100,
reg_init=settings.REG[0], reg_final=settings.REG[1],
rad_init=settings.RAD[0], rad_final=settings.RAD[1],
Expand All @@ -292,37 +339,19 @@ def multi_tps_l2(x_kld, y_md, p_ktd, ctrl_knd=None,

# intitalize z from independent optimizations
f_k = []
QN_ktn = []
for (x_ld, p_td, ctrl_nd) in zip(x_kld, p_ktd, ctrl_knd):
n, d = ctrl_nd.shape
f = tps_l2(x_ld, y_md, ctrl_nd=ctrl_nd, n_iter=n_iter, opt_iter=opt_iter, reg_init=reg_init, reg_final=reg_final, rad_init=rad_init, rad_final=rad_final, rot_reg=rot_reg, callback=callback)
f_k.append(f)
QN_tn = f.compute_transform_grad(p_td)
QN_ktn.append(QN_tn)
z_knd = multi_tps_to_params(f_k)

# put together matrix for computing sum of variances
# the sum of variances is given by np.sum(np.square(L_ktkn.dot(z_knd.reshape((-1,d)))))
k, t, _ = p_ktd.shape
L_ktkn = []
for j in range(t):
QN_1kn = []
for QN_tn in QN_ktn:
QN_1kn.append(QN_tn[j,:])
QN_1kn = np.concatenate(QN_1kn)
i = 0
for QN_tn in QN_ktn:
_, n = QN_tn.shape
L_1kn = (-1/k) * QN_1kn
L_1kn[i:i+n] += QN_tn[j,:]
L_ktkn.append(L_1kn)
i += n
L_ktkn = (1/k) * np.array(L_ktkn)
L_ktkn = compute_sum_var_matrix(f_k, p_ktd)

if multi_callback is not None:
multi_callback(f_k, y_md, p_ktd)

res = so.fmin_l_bfgs_b(multi_tps_l2_obj, z_knd, None, args=(f_k, L_ktkn, y_md, p_ktd, rad_final, reg_final, rot_reg, cov_coef), maxfun=opt_iter)
res = so.fmin_l_bfgs_b(pairwise_tps_l2_cov_obj, z_knd, None, args=(f_k, y_md, p_ktd, rad_final, reg_final, rot_reg, cov_coef, L_ktkn), maxfun=opt_iter)
z_knd = res[0]

f_k = params_to_multi_tps(z_knd, f_k)
Expand Down

0 comments on commit cd754f9

Please sign in to comment.