Skip to content

Commit

Permalink
Finish implementing TpsnRegistration and its factory. Fix bug in obje…
Browse files Browse the repository at this point in the history
…ctive2.
  • Loading branch information
alexlee-gk committed Feb 25, 2015
1 parent ab4a43f commit 5d32e02
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 34 deletions.
76 changes: 44 additions & 32 deletions lfd/registration/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,24 +93,23 @@ class TpsL2Registration(Registration):


class TpsnRpmRegistration(Registration):
def __init__(self, demo, test_scene_state, f, corr, x_ld, u_rd, z_rd, y_md, v_sd, z_sd, rad, radn, bend_coef, rot_coef):
super(TpsRpmRegistration, self).__init__(demo, test_scene_state, f, corr)
self.x_ld = x_ld
self.u_rd = u_rd
self.z_rd = z_rd
self.y_md = y_md
self.v_sd = v_sd
self.z_sd = z_sd
def __init__(self, demo, test_scene_state, f, corr_lm, corr_rs, rad, radn, bend_coef, rot_coef):
super(TpsnRpmRegistration, self).__init__(demo, test_scene_state, f, corr_lm)
self.x_ld = demo.scene_state.cloud[:,:3]
self.u_rd = demo.scene_state.normals
self.z_rd = demo.scene_state.sites
self.y_md = test_scene_state.cloud[:,:3]
self.v_sd = test_scene_state.normals
self.z_sd = test_scene_state.sites
self.corr_lm = corr_lm
self.corr_rs = corr_rs
self.rad = rad
self.radn = radn
self.bend_coef = bend_coef
self.rot_coef = rot_coef
self.rot_coef = rot_coef

def get_objective(self):
x_nd = self.demo.scene_state.cloud[:,:3]
y_md = self.test_scene_state.cloud[:,:3]
# TODO: fill x_ld, u_rd, z_rd, y_md, v_sd, z_sd
cost = self.get_objective2(x_ld, u_rd, z_rd, y_md, v_sd, z_sd, self.f, self.corr_lm, self.corr_rs, self.rad, self.radn, self.bend_coef, self.rot_coef)
cost = self.get_objective2(self.x_ld, self.u_rd, self.z_rd, self.y_md, self.v_sd, self.z_sd, self.f, self.corr_lm, self.corr_rs, self.rad, self.radn, self.bend_coef, self.rot_coef)
return cost

@staticmethod
Expand Down Expand Up @@ -157,9 +156,9 @@ def get_objective2(x_ld, u_rd, z_rd, y_md, v_sd, z_sd, f, corr_lm, corr_rs, rad,
# normal entropy
corr_rs = np.reshape(corr_rs, (1,-1))
nz_corr_rs = corr_rs[corr_rs != 0]
site_dist_rs = np.reshape(site_dist_rs, (1,-1))
nz_site_dist_rs = site_dist_rs[corr_rs != 0]
cost[6] = (2*radn / r) * (nz_corr_rs * np.log(nz_corr_rs / nz_site_dist_rs)).sum()
prior_prob_rs = np.reshape(prior_prob_rs, (1,-1))
nz_prior_prob_rs = prior_prob_rs[corr_rs != 0]
cost[6] = (2*radn / r) * (nz_corr_rs * np.log(nz_corr_rs / nz_prior_prob_rs)).sum()
cost[7] = -(2*radn / r) * nz_corr_rs.sum()
return cost

Expand Down Expand Up @@ -603,34 +602,47 @@ class TpsnRpmRegistrationFactory(RegistrationFactory):
TPS-RPM using normals information
"""
def __init__(self, demos=None,
n_iter=settings.N_ITER, em_iter=settings.EM_ITER,
reg_init=settings.REG[0], reg_final=settings.REG[1],
rad_init=settings.RAD[0], rad_final=settings.RAD[1],
rot_reg=settings.ROT_REG,
outlierprior=settings.OUTLIER_PRIOR, outlierfrac=settings.OUTLIER_FRAC,
prior_fn=None,
f_solver_factory=solver.AutoTpsSolverFactory()):
raise NotImplementedError
n_iter=settings.N_ITER, em_iter=settings.EM_ITER,
reg_init=settings.REG[0], reg_final=settings.REG[1],
rad_init=settings.RAD[0], rad_final=settings.RAD[1],
radn_init=settings.RADN[0], radn_final=settings.RADN[1],
nu_init=settings.NU[0], nu_final=settings.NU[1],
rot_reg=settings.ROT_REG,
outlierprior=settings.OUTLIER_PRIOR, outlierfrac=settings.OUTLIER_FRAC,
callback=None):
self.n_iter = n_iter
self.em_iter = em_iter
self.reg_init = reg_init
self.reg_final = reg_final
self.rad_init = rad_init
self.rad_final = rad_final
self.radn_init = radn_init
self.radn_final = radn_final
self.nu_init = nu_init
self.nu_final = nu_final
self.rot_reg = rot_reg
self.outlierprior = outlierprior
self.outlierfrac = outlierfrac

def register(self, demo, test_scene_state, callback=None):
if self.prior_fn is not None:
prior_prob_nm = self.prior_fn(demo.scene_state, test_scene_state)
else:
prior_prob_nm = None
x_nd = demo.scene_state.cloud[:,:3]
x_ld = demo.scene_state.cloud[:,:3]
u_rd = demo.scene_state.normals
z_rd = demo.scene_state.sites
y_md = test_scene_state.cloud[:,:3]
v_sd = test_scene_state.normals
z_sd = test_scene_state.sites

f, corr_lm, corr_rs = tps_experimental.tpsn_rpm(x_ld, u_rd, z_rd, y_md, v_sd, z_sd,
n_iter=self.n_iter, em_iter=self.em_iter,
reg_init=self.reg_init, reg_final=self.reg_final,
rad_init=self.rad_init, rad_final=self.rad_final,
radn_init=self.radn_init, radn_final=self.radn_final,
nu_init=nu_init, nu_final=nu_final,
nu_init=self.nu_init, nu_final=self.nu_final,
rot_reg=self.rot_reg,
outlierprior=self.outlierprior, outlierfrac=self.outlierfrac,
callback=callback)

return TpsnRpmRegistration(demo, test_scene_state, f, corr, self.rad_final)
return TpsnRpmRegistration(demo, test_scene_state, corr_lm, corr_rs, self.rad_final, self.radn_final, self.reg_final, self.rot_coef)

def cost(self, demo, test_scene_state):
raise NotImplementedError
2 changes: 0 additions & 2 deletions lfd/registration/tpsn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,6 @@ def tps_callback(i, i_em, x_ld, y_md, xtarg_ld, wt_n, f, corr_lm, rad):
# reg_init = 1
reg_final = .1
if tpsn_min_param is None:
# rad_init, reg_init, radn_init, radn_final, nu_init, nu_final = 0, 10, 0.005, 0.001, 0.01, 0.1

tpsn_min_cost = float('inf')
tpsn_min_param = None
for rad_init in [0.1, 1, 10] if tpsn_min_param_ranges is None else tpsn_min_param_ranges[0]: #[2**i for i in range(-3, 4)]: #[0.1, 1, 10]:
Expand Down

0 comments on commit 5d32e02

Please sign in to comment.