From 91525113811ed1ba1da85e76fba9a8a7b5dd4784 Mon Sep 17 00:00:00 2001 From: dgursoy Date: Thu, 4 May 2023 10:58:58 -0500 Subject: [PATCH] Fix/update align_joint --- source/tomopy/prep/alignment.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/source/tomopy/prep/alignment.py b/source/tomopy/prep/alignment.py index 215c7f4c2..abc70a93b 100644 --- a/source/tomopy/prep/alignment.py +++ b/source/tomopy/prep/alignment.py @@ -214,7 +214,7 @@ def align_seq( def align_joint( - prj, ang, fdir='.', iters=10, pad=(0, 0), + prj, ang, fdir='./', iters=10, pad=(0, 0), blur=True, center=None, algorithm='sirt', upsample_factor=10, rin=0.5, rout=0.8, save=False, debug=True): @@ -300,6 +300,9 @@ def align_joint( if algorithm != 'gridrec': extra_kwargs['num_iter'] = 1 + # Make a copy of the projections + prj_copy = prj.copy() + # Register each image frame-by-frame. for n in range(iters): @@ -307,7 +310,7 @@ def align_joint( _rec = rec # Reconstruct image. - rec = recon(prj, ang, center=center, algorithm=algorithm, + rec = recon(prj_copy, ang, center=center, algorithm=algorithm, init_recon=_rec, **extra_kwargs) # Re-project data and obtain simulated data. @@ -315,10 +318,10 @@ def align_joint( # Blur edges. if blur: - _prj = blur_edges(prj, rin, rout) + _prj = blur_edges(prj_copy, rin, rout) _sim = blur_edges(sim, rin, rout) else: - _prj = prj + _prj = prj_copy _sim = sim # Initialize error matrix per iteration. @@ -329,23 +332,27 @@ def align_joint( # Register current projection in sub-pixel precision shift, error, diffphase = phase_cross_correlation( - _prj[m], _sim[m], upsample_factor=upsample_factor) + _prj[m], _sim[m], normalization=None, + upsample_factor=upsample_factor) err[m] = np.sqrt(shift[0]*shift[0] + shift[1]*shift[1]) sx[m] += shift[0] sy[m] += shift[1] # Register current image with the simulated one - tform = tf.SimilarityTransform(translation=(shift[1], shift[0])) - prj[m] = tf.warp(prj[m], tform, order=5) + tform = tf.SimilarityTransform(translation=(sy[m], sx[m])) + prj_copy[m] = tf.warp(prj[m].copy(), tform, order=5) if debug: - print('iter=' + str(n) + ', err=' + str(np.linalg.norm(err))) + print('iter=' + str(n) + + ', err=' + str(np.linalg.norm(err) / prj.shape[0])) conv[n] = np.linalg.norm(err) if save: - write_tiff(prj, 'tmp/iters/prj', n) - write_tiff(sim, 'tmp/iters/sim', n) - write_tiff(rec, 'tmp/iters/rec', n) + write_tiff(_prj, fdir + 'tmp/iters/prj', n) + write_tiff(sim, fdir + 'tmp/iters/sim', n) + write_tiff(rec, fdir + 'tmp/iters/rec', n) + write_tiff(sx, fdir + 'tmp/iters/sx', n) + write_tiff(sy, fdir + 'tmp/iters/sy', n) # Re-normalize data prj *= scl