In [2]:
%matplotlib qt
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import fourier_shift
from skimage.registration import phase_cross_correlation
import skimage.filters as skf
from pathlib import Path
import multiprocessing as mp
import astra
import h5py
import tifffile
import os, gc, shutil
import time
import tomopy, copy
from pystackreg import StackReg

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

# ***1. read and preprocess data***

In [None]:
dfn = Path('/media/xiao_usb/High_res_TXM_CT_2020/fly_scan_id_48753.h5')
#ofn = Path('/media/Disk2/data/Yuan_2019_02/aligned_fly_scan_id_12275.h5')
#ofn = Path('/media/Disk2/data/Yuan_2019_02/recon_fly_scan_id_12275.tif')

f = h5py.File(dfn, 'r')
imgs = f['img_tomo'][:, 800:900, :].astype(np.float32)
bkg = f['img_bkg'][:, 800:900, :].astype(np.float32)
dark = f['img_dark'][:, 800:900, :].astype(np.float32)
theta = f['angle'][:]*np.pi/180.
f.close()

#imgs = tomopy.prep.normalize.normalize(imgs, bkg, dark)
#imgs = tomopy.prep.normalize.minus_log(imgs)
imgs[:] = ((imgs-dark.mean(axis=0))/(bkg.mean(axis=0)-dark.mean(axis=0)))[:]
imgs[:] = tomopy.prep.stripe.remove_all_stripe(imgs, la_size=161, sm_size=31)[:]
imgs[:] = -np.log(imgs)[:]
imgs[:] = np.where(np.isnan(imgs), 0, imgs)[:]
imgs[:] = np.where(np.isinf(imgs), 0, imgs)[:]

imgs = np.swapaxes(imgs, 0, 1)
ny, na, nx = imgs.shape
#print(na, ny, nx)

In [None]:
plt.figure(10)
plt.imshow(imgs[0, ...])

In [None]:
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
plt.figure(100)
#plt.imshow(skf.meijering(imgs[:, 500, :]))
#plt.imshow(skf.frangi(imgs[:, 500, :]))
plt.imshow(skf.sato(imgs[:, 500, :]))
#plt.imshow(skf.hessian(imgs[:, 500, :]))
#plt.imshow(skf.roberts_pos_diag(imgs[:, 500, :]))
#plt.imshow(skf.laplace(imgs[:, 500, :]))

# ***2. tomopy recon to find center***

In [None]:
data_center_path = "/media/xiao_usb/High_res_TXM_CT_2020/data_center"
if os.path.exists(data_center_path):
    [f.unlink() for f in Path(data_center_path).glob("*") if f.is_file()]
else:
    os.makedirs(data_center_path, mode=777)
center_shift = -80
center_shift_w = 80
tomopy.write_center(np.swapaxes(imgs, 0, 1)[:, 40:60, :], theta, dpath=data_center_path,
                 cen_range=(nx/2+center_shift, nx/2+center_shift+center_shift_w, 0.5),
                 mask = True, ratio = 1, algorithm = 'gridrec', filter_name = 'parzen')

In [None]:
cen = 1243
recon = tomopy.recon(np.swapaxes(imgs, 0, 1), theta, center=cen, algorithm='gridrec', filter_name = 'parzen')

In [None]:
plt.figure(0)
plt.imshow(recon[50, ...])

# ***3. shift data to correct center offset***

In [None]:
cen = 1243
shift = nx/2. - cen 
def shift_cen(img, shift):
    return np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), [0, shift])))

n_cpu = os.cpu_count()
with mp.Pool(n_cpu-1) as pool:
    rlt = pool.starmap(shift_cen, [(imgs[:, ii, :], shift) for ii in np.int32(np.arange(na))])
pool.join()
pool.close()
for ii in range(na):
    imgs[:, ii, :] = rlt[ii][:, :]


In [None]:
del rlt
gc.collect()

In [None]:
plt.figure(1)
plt.imshow(imgs[:, 100, :])

# ***4. define astra data structure***

In [None]:
proj_geom = astra.creators.create_proj_geom(
    'parallel3d', 1., 1., ny, nx, theta)
# print(proj_geom)
vol_geom = astra.creators.create_vol_geom(nx, nx, ny)
# print(astra.data3d.get_geometry(proj_geom))

In [None]:
print(vol_geom['GridColCount'])

In [None]:
# reconstruction with astra
recon_id = astra.data3d.create('-vol', vol_geom, data=recon[:, ::-1, :])
#recon_id = astra.data3d.create('-vol', vol_geom, data=0)
proj_id = astra.data3d.create('-sino', proj_geom, imgs)

In [None]:
print(proj_id)
proj = astra.data3d.get(proj_id)
print(ny, nx, theta.shape)
print(imgs.shape)
print(proj.shape)

In [None]:
plt.figure(2)
plt.imshow(proj[:, 500, :])

# ***5. config astra algorithm***

In [None]:
print(time.asctime())
alg_cfg = astra.astra_dict('SIRT3D_CUDA')
#alg_cfg = astra.astra_dict('BP3D_CUDA')

alg_cfg['ProjectionDataId'] = proj_id
alg_cfg['ReconstructionDataId'] = recon_id
#alg_cfg['option'] = {}
#alg_cfg['option']['MinConstraint'] = 0
algorithm_id = astra.algorithm.create(alg_cfg)
astra.algorithm.run(algorithm_id, 20)
print(time.asctime())

In [None]:
recon = astra.data3d.get(recon_id)
plt.figure(3)
plt.imshow(recon[50, :, :])

# ***6. reproject recon***

In [None]:
#astra.functions.move_vol_geom(vol_geom, (0, shift, 0), is_relative=False)
fp_id, fp_data = astra.create_sino3d_gpu(recon, proj_geom, vol_geom)

In [None]:
plt.figure(4)
#plt.imshow(fp_data[:, 500, :])
plt.imshow(skf.sato(fp_data[:, 500, :]))

In [None]:
plt.figure(5)
plt.imshow(fp_data[:, 500, :] - proj[:, 500, :])

# ***7. register proj images***

In [None]:
def register(ref, img, thres=1e-3):
    if method is None:
        shift = phase_cross_correlation(skf.sato(ref), skf.sato(img), upsample_factor=100)
        shifted_img = np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), shift)))
        return shift, shifted_img
    else:
        mask = (skf.sato(img)>thres)
        shift = phase_cross_correlation(skf.sato(ref), skf.sato(img), upsample_factor=100, reference_mask=mask, overlap_ratio=0.3)
        shifted_img = np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), shift)))
        return shift, shifted_img

In [None]:
mask = (skf.sato(imgs[:, 0, :])>thres)
tem = phase_cross_correlation(skf.sato(fp_data[:, 0, :]), skf.sato(imgs[:, 0, :]), upsample_factor=100, reference_mask=mask, overlap_ratio=0.3)

In [None]:
print(tem)

In [None]:
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
print(time.asctime())

thres = 1e-3
n_cpu = os.cpu_count()
with mp.Pool(n_cpu-1) as pool:
    rlt = pool.starmap(register, [(fp_data[:, ii, :], imgs[:, ii, :], 1e-3) for ii in np.int32(np.arange(na))])
pool.join()
pool.close()
print(time.asctime())

In [None]:
print(time.asctime())
shift = np.zeros([imgs.shape[1], 2])
for ii in range(na):
    shift[ii] = rlt[ii][0]
    imgs[:, ii, :] = rlt[ii][1][:, :]
print(time.asctime())

In [None]:
print(shift[0:100])

In [None]:
help(phase_cross_correlation)

# ***8. clean up astra***

In [None]:
# astra: cleanup
astra.algorithm.delete(algorithm_id)
astra.data3d.delete(recon_id)
astra.data3d.delete(proj_id)
astra.data3d.delete(fp_id)
astra.functions.clear()

In [None]:
help(astra.geom_postalignment)

# ***9. integrated alignment***

In [11]:
sli_s = 800
sli_e = 900
itr = 4
thres = 3e-4
method = 'sr'
mode = 'RIGID_BODY' # 'TRANSLATION', 'RIGID_BODY', 'SCALED_ROTATION', 'AFFINE', 'BILINEAR'
"""

    translation
    rigid body (translation + rotation)
    scaled rotation (translation + rotation + scaling)
    affine (translation + rotation + scaling + shearing)
    bilinear (non-linear transformation; does not preserve straight lines)
    
"""

def register(ref, img, method='pcc', thres=1e-3, mode=None):
    if method == 'pcc':
        if thres is None:
            shift, _, _ = phase_cross_correlation(skf.sato(ref), skf.sato(img), upsample_factor=100)
            #print(shift)
            shifted_img = np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), shift)))
            return shift, shifted_img
        else:
            mask = (skf.sato(img)>thres)
            shift = phase_cross_correlation(skf.sato(ref), skf.sato(img), upsample_factor=100, reference_mask=mask, overlap_ratio=0.3)
            #print(shift)
            shifted_img = np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), shift)))
            return shift, shifted_img
    elif method == 'sr':
        if mode.upper() == 'TRANSLATION':
            sr = StackReg(StackReg.TRANSLATION)
        elif mode.upper() == 'RIGID_BODY':
            sr = StackReg(StackReg.RIGID_BODY)
        elif mode.upper() == 'SCALED_ROTATION':
            sr = StackReg(StackReg.SCALED_ROTATION)
        elif mode.upper() == 'AFFINE':
            sr = StackReg(StackReg.AFFINE)
        elif mode.upper() == 'BILINEAR':
            sr = StackReg(StackReg.BILINEAR)
        shift = sr.register(skf.sato(ref), skf.sato(img))
        shifted_img = sr.transform(img, tmat=shift)
        return shift, shifted_img

basename = 'fly_scan_id_48753_step_0.05_cen_-1.75_out_plane_series'
dfn = Path(f'/media/xiao_usb/High_res_TXM_CT_2020/{basename}.h5')
ofn = Path(f'/media/xiao_usb/High_res_TXM_CT_2020/{basename}_{method}_{mode}.h5')

with h5py.File(dfn, 'r') as f:
    imgs = f['img_tomo'][:, sli_s:sli_e, :].astype(np.float32)
    bkg = f['img_bkg'][:, sli_s:sli_e, :].astype(np.float32)
    dark = f['img_dark'][:, sli_s:sli_e, :].astype(np.float32)
    theta = f['angle'][:]*np.pi/180.
imgs[:] = ((imgs-dark.mean(axis=0))/(bkg.mean(axis=0)-dark.mean(axis=0)))[:]
imgs[:] = tomopy.prep.stripe.remove_all_stripe(imgs, la_size=161, sm_size=31)[:]
imgs[:] = -np.log(imgs)[:]
imgs[:] = np.where(np.isnan(imgs), 0, imgs)[:]
imgs[:] = np.where(np.isinf(imgs), 0, imgs)[:]

"""
with h5py.File(dfn, 'r') as f:
    imgs = f['/proj_out_plane_corr/tilted_proj/tilt_-1.9_deg_proj'][:].astype(np.float32)
with h5py.File('/media/xiao_usb/High_res_TXM_CT_2020/fly_scan_id_48753.h5', 'r') as f:
    theta = f['angle'][:]*np.pi/180.
"""
    
with h5py.File(ofn, 'w') as f:
    if '/proj_corr' in f:
        del f['/proj_corr']
    g0 = f.create_group('/proj_corr')
    g0.create_group('shift_by_reproj') 
    g0.create_group('corrected_proj')
    g0.create_group('corrected_recon')
    g02 = g0.create_group('proj_corr_config')
    g02.create_dataset('method', data=str(method))
    g02.create_dataset('mode', data=str(mode))
    g02.create_dataset('slice', data=[sli_s, sli_e])
    if thres is None:
        g02.create_dataset('thres', data=str(thres))
    else:
        g02.create_dataset('thres', data=thres)        



imgs = np.swapaxes(imgs, 0, 1)
ny, na, nx = imgs.shape

cen = 1243
offset = nx/2. - cen 
def shift_cen(img, offset):
    return np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), [0, offset])))

n_cpu = os.cpu_count()
with mp.Pool(n_cpu-1) as pool:
    rlt = pool.starmap(shift_cen, [(imgs[:, ii, :], offset) for ii in np.int32(np.arange(na))])
pool.join()
pool.close()
for ii in range(na):
    imgs[:, ii, :] = rlt[ii][:, :]
del(rlt)
gc.collect()

if method == 'pcc':
    shift = np.zeros([imgs.shape[1], 2])
elif method == 'sr':
    shift = np.zeros([imgs.shape[1], 3, 3])

cen = 1280
recon = tomopy.recon(np.swapaxes(imgs, 0, 1), theta, center=cen, algorithm='gridrec', filter_name = 'parzen')[:, ::-1, :]

ori_imgs = copy.deepcopy(imgs)

proj_geom = astra.creators.create_proj_geom(
    'parallel3d', 1., 1., ny, nx, theta)
vol_geom = astra.creators.create_vol_geom(nx, nx, ny)

for ii in range(itr):
    print(f"{ii}th iteration starts at {time.asctime()}")
    recon_id = astra.data3d.create('-vol', vol_geom, data=recon)
    proj_id = astra.data3d.create('-sino', proj_geom, imgs)

    print(f"    astra recons starts at {time.asctime()}")
    alg_cfg = astra.astra_dict('SIRT3D_CUDA')
    #alg_cfg = astra.astra_dict('BP3D_CUDA')

    alg_cfg['ProjectionDataId'] = proj_id
    alg_cfg['ReconstructionDataId'] = recon_id
    #alg_cfg['option'] = {}
    #alg_cfg['option']['MinConstraint'] = 0
    algorithm_id = astra.algorithm.create(alg_cfg)
    astra.algorithm.run(algorithm_id, 20)
    print(f"    astra recons finishes at {time.asctime()}")

    recon = astra.data3d.get(recon_id)
    #plt.figure(3)
    #plt.imshow(recon[50, :, :])

    fp_id, fp_data = astra.create_sino3d_gpu(recon, proj_geom, vol_geom)

    warnings.filterwarnings('ignore')
    warnings.simplefilter('ignore')
    print(f"    registration starts at {time.asctime()}")

    n_cpu = os.cpu_count()
    with mp.Pool(n_cpu-1) as pool:
        rlt = pool.starmap(register, [(fp_data[:, ii, :], imgs[:, ii, :], method, thres, mode) for ii in np.int32(np.arange(na))])
    pool.join()
    pool.close()
    print(f"    astra recons finishes at {time.asctime()}")

    for jj in range(na):
        shift[jj] = rlt[jj][0]
        imgs[:, jj, :] = rlt[jj][1][:, :]
    del(rlt)
    gc.collect()
    
    with h5py.File(ofn, 'a') as f:
        g0 = f['proj_corr']       
        g0['shift_by_reproj'] .create_dataset(f'iter_{str(ii).zfill(2)}', data=shift.astype(np.float32), dtype=np.float32)
        g0['corrected_proj'] .create_dataset(f'iter_{str(ii).zfill(2)}_shifted_images', data=imgs.astype(np.float32), dtype=np.float32)
        g0['corrected_recon'] .create_dataset(f'iter_{str(ii).zfill(2)}_recon', data=recon.astype(np.float32), dtype=np.float32)

    astra.algorithm.delete(algorithm_id)
    astra.data3d.delete(recon_id)
    astra.data3d.delete(proj_id)
    astra.data3d.delete(fp_id)
    astra.functions.clear()
    print(f"{ii}th iteration finishes at {time.asctime()}")
    
recon_id = astra.data3d.create('-vol', vol_geom, data=recon[:, ::-1, :])
proj_id = astra.data3d.create('-sino', proj_geom, imgs)

print(f"    astra recons starts at {time.asctime()}")
alg_cfg = astra.astra_dict('SIRT3D_CUDA')
#alg_cfg = astra.astra_dict('BP3D_CUDA')

alg_cfg['ProjectionDataId'] = proj_id
alg_cfg['ReconstructionDataId'] = recon_id
#alg_cfg['option'] = {}
#alg_cfg['option']['MinConstraint'] = 0
algorithm_id = astra.algorithm.create(alg_cfg)
astra.algorithm.run(algorithm_id, 20)
print(f"    astra recons finishes at {time.asctime()}")

recon = astra.data3d.get(recon_id)

astra.algorithm.delete(algorithm_id)
astra.data3d.delete(recon_id)
astra.data3d.delete(proj_id)
astra.data3d.delete(fp_id)
astra.functions.clear()

0th iteration starts at Wed Sep  9 00:01:04 2020
    astra recons starts at Wed Sep  9 00:01:07 2020
    astra recons finishes at Wed Sep  9 00:05:50 2020
    registration starts at Wed Sep  9 00:06:01 2020
    astra recons finishes at Wed Sep  9 00:11:23 2020
0th iteration finishes at Wed Sep  9 00:11:35 2020
1th iteration starts at Wed Sep  9 00:11:35 2020
    astra recons starts at Wed Sep  9 00:11:38 2020
    astra recons finishes at Wed Sep  9 00:16:22 2020
    registration starts at Wed Sep  9 00:16:33 2020
    astra recons finishes at Wed Sep  9 00:21:52 2020
1th iteration finishes at Wed Sep  9 00:22:04 2020
2th iteration starts at Wed Sep  9 00:22:04 2020
    astra recons starts at Wed Sep  9 00:22:07 2020
    astra recons finishes at Wed Sep  9 00:26:52 2020
    registration starts at Wed Sep  9 00:27:03 2020
    astra recons finishes at Wed Sep  9 00:32:24 2020
2th iteration finishes at Wed Sep  9 00:32:40 2020
3th iteration starts at Wed Sep  9 00:32:40 2020
    astra recon

In [None]:
plt.figure(1000)
plt.imshow(recon[60, :, :])

In [None]:
start=500
print(shift[start:start+100])

In [None]:
plt.figure(50)
plt.plot(shift[:, 0])
plt.figure(51)
plt.plot(shift[:, 1])

In [None]:
plt.figure(52)
plt.imshow(skf.sato(fp_data[:, 500, :])>3e-4)
plt.figure(53)
plt.imshow(skf.sato(imgs[:, 500, :])>3e-4)

# ***10. apply correction to entire dataset***

In [None]:
print(f'correcting entire dataset starts at {time.asctime()}')

dfn = Path('/media/xiao_usb/High_res_TXM_CT_2020/fly_scan_id_48753.h5')
ofn = Path('/media/xiao_usb/High_res_TXM_CT_2020/fly_scan_id_48753_sr_RIGID_BODY.h5')
with h5py.File(ofn, 'a') as fo:
    with h5py.File(dfn, 'r') as fi:
        angle = fi['/angle'][:]
        if fo['/angle'] is not None:
            del fo['/angle']
        fo.create_dataset('/angle', data=angle.astype(np.float32))
        
        X_eng = fi['/X_eng'][()]
        if fo['/X_eng'] is not None:
            del fo['/X_eng']
        fo.create_dataset('/X_eng', data=X_eng.astype(np.float32))
        
        img_bkg = fi['/img_bkg'][:]
        if fo['/img_bkg'] is not None:
            del fo['/img_bkg']
        fo.create_dataset('/img_bkg', data=img_bkg.astype(np.uint16), dtype=np.uint16)
        
        img_dark = fi['/img_dark'][:]
        if fo['/img_dark'] is not None:
            del fo['/img_dark']
        fo.create_dataset('/img_dark', data=img_dark.astype(np.uint16), dtype=np.uint16)
        
        if fo['/img_tomo'] is not None:
            del fo['/img_tomo']
        g0 = fo.create_dataset('/img_tomo', shape=(fi['/img_tomo'].shape), dtype=np.uint16)
        
        g1 = fo['/proj_corr']
        
        """
        g1 = fo.create_group('proj_corr_config')
        mthod = fi['/proj_corr/proj_corr_config/method'][()]
        g1.create_dataset('method', data=str(method))
        mode = fi['/proj_corr/proj_corr_config/mode'][()]
        g1.create_dataset('mode', data=str(mode))
        thres = fi['/proj_corr/proj_corr_config/thres'][()]
        g1.create_dataset('thres', data=thres)
        """
        
        if fo['/proj_corr/proj_corr_config/method'][()] == 'sr':            
            if mode.upper() == 'TRANSLATION':
                sr = StackReg(StackReg.TRANSLATION)
            elif mode.upper() == 'RIGID_BODY':
                sr = StackReg(StackReg.RIGID_BODY)
            elif mode.upper() == 'SCALED_ROTATION':
                sr = StackReg(StackReg.SCALED_ROTATION)
            elif mode.upper() == 'AFFINE':
                sr = StackReg(StackReg.AFFINE)
            elif mode.upper() == 'BILINEAR':
                sr = StackReg(StackReg.BILINEAR)
            
            i = len(fo['/proj_corr/shift_by_reproj'])
            n = fo['/proj_corr/shift_by_reproj/iter_00'].shape[0]
            g10 = g1.create_dataset('overall_shift', shape=(n, 3, 3), dtype=np.float32)
            for ii in range(n):
                shift = np.identity(3)
                tem = fo[f'/proj_corr/shift_by_reproj/iter_{str(jj).zfill(2)}'][ii]
                for jj in range(i):
                    shift = np.matmul(shift, tem[jj])
                g10[ii] = shift
                g0[ii] = np.round(sr.transform(fi['img_tomo'][ii], tmat=shift)).astype(np.uint16)
        elif fo['/proj_corr/proj_corr_config/method'][()] == 'pcc':
            i = len(fo['/proj_corr/shift_by_reproj'])
            n = fo['/proj_corr/shift_by_reproj/iter_00'].shape[0]            
            g10 = g1.create_dataset('overall_shift', shape=(n, 2), dtype=np.float32)
            for ii in range(n):
                shift = 0
                for jj in range(i):
                    shift += fo[f'/proj_corr/shift_by_reproj/iter_{str(jj).zfill(2)}'][ii]  
                g10[ii] = shift
                g0[ii] = np.round(np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(fi['/img_tomo'][ii]), shift)))).astype(np.uint16)
print(f'correcting entire dataset ends at {time.asctime()}')

In [None]:
a = np.tile(np.identity(3), (10, 1, 1))
b = np.repeat(np.identity(3), 10).reshape([3, 3, 10])
print(a.shape, b.shape)
print(a)

In [None]:
help(np.repeat)

# ***11. out of plane angle correction***

In [10]:
itr = 25
range_cen = -1.75
step = 0.05 # degree
#method = 'sr'
#mode = 'RIGID_BODY' # 'TRANSLATION', 'RIGID_BODY', 'SCALED_ROTATION', 'AFFINE', 'BILINEAR'
"""

    translation
    rigid body (translation + rotation)
    scaled rotation (translation + rotation + scaling)
    affine (translation + rotation + scaling + shearing)
    bilinear (non-linear transformation; does not preserve straight lines)
    
"""

def register(ref, img, method='pcc', thres=1e-3, mode=None):
    if method == 'pcc':
        if thres is None:
            shift, _, _ = phase_cross_correlation(skf.sato(ref), skf.sato(img), upsample_factor=100)
            #print(shift)
            shifted_img = np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), shift)))
            return shift, shifted_img
        else:
            mask = (skf.sato(img)>thres)
            shift = phase_cross_correlation(skf.sato(ref), skf.sato(img), upsample_factor=100, reference_mask=mask, overlap_ratio=0.3)
            #print(shift)
            shifted_img = np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), shift)))
            return shift, shifted_img
    elif method == 'sr':
        if mode.upper() == 'TRANSLATION':
            sr = StackReg(StackReg.TRANSLATION)
        elif mode.upper() == 'RIGID_BODY':
            sr = StackReg(StackReg.RIGID_BODY)
        elif mode.upper() == 'SCALED_ROTATION':
            sr = StackReg(StackReg.SCALED_ROTATION)
        elif mode.upper() == 'AFFINE':
            sr = StackReg(StackReg.AFFINE)
        elif mode.upper() == 'BILINEAR':
            sr = StackReg(StackReg.BILINEAR)
        shift = sr.register(skf.sato(ref), skf.sato(img))
        shifted_img = sr.transform(img, tmat=shift)
        return shift, shifted_img
    
#ifn = Path('/media/xiao_usb/High_res_TXM_CT_2020/fly_scan_id_48753_sr_RIGID_BODY.h5')
#ofn = Path('/media/xiao_usb/High_res_TXM_CT_2020/fly_scan_id_48753_sr_RIGID_BODY.h5')
ifn = Path('/media/xiao_usb/High_res_TXM_CT_2020/fly_scan_id_48753.h5')
ofn = Path(f'/media/xiao_usb/High_res_TXM_CT_2020/fly_scan_id_48753_step_{step}_cen_{range_cen}_out_plane_series.h5')

sli_s = 800
sli_e = 900

with h5py.File(ifn, 'r') as f:
    imgs = f['img_tomo'][:, sli_s:sli_e, :].astype(np.float32)
    bkg = f['img_bkg'][:, sli_s:sli_e, :].astype(np.float32)
    dark = f['img_dark'][:, sli_s:sli_e, :].astype(np.float32)
    theta = f['angle'][:]*np.pi/180.
    #method = f['/proj_corr/proj_corr_config/method'][()]
    #mode = f['/proj_corr/proj_corr_config/mode'][()] # 'TRANSLATION', 'RIGID_BODY', 'SCALED_ROTATION', 'AFFINE', 'BILINEAR'
    
with h5py.File(ofn, 'a') as f:
    if '/proj_out_plane_corr' in f:
        del f['/proj_out_plane_corr']
    g0 = f.create_group('/proj_out_plane_corr')
    g01 = g0.create_group('tilted_proj')
    g02 = g0.create_group('proj_corr_config')
    g02.create_dataset('method', data='astra')
    g02.create_dataset('mode', data='proj_vector correction')
    g02.create_dataset('step', data=step, dtype=np.float32)
    g02.create_dataset('angle_range', data=step, dtype=np.float32)
         

imgs[:] = ((imgs-dark.mean(axis=0))/(bkg.mean(axis=0)-dark.mean(axis=0)))[:]
imgs[:] = tomopy.prep.stripe.remove_all_stripe(imgs, la_size=161, sm_size=31)[:]
imgs[:] = -np.log(imgs)[:]
imgs[:] = np.where(np.isnan(imgs), 0, imgs)[:]
imgs[:] = np.where(np.isinf(imgs), 0, imgs)[:]

imgs = np.swapaxes(imgs, 0, 1)
ny, na, nx = imgs.shape

cen = 1243
offset = nx/2. - cen 
def shift_cen(img, offset):
    return np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), [0, offset])))

n_cpu = os.cpu_count()
with mp.Pool(n_cpu-1) as pool:
    rlt = pool.starmap(shift_cen, [(imgs[:, ii, :], offset) for ii in np.int32(np.arange(na))])
    
pool.join()
pool.close()

for ii in range(na):
    imgs[:, ii, :] = rlt[ii][:, :]
del(rlt)
gc.collect()

cen = 1280

ori_imgs = copy.deepcopy(imgs)

proj_geom = astra.creators.create_proj_geom('parallel3d', 1., 1., ny, nx, theta)
vol_geom = astra.creators.create_vol_geom(nx, nx, ny)

for ii in range(itr):
    print(f"{ii}th iteration starts at {time.asctime()}")
    recon_id = astra.data3d.create('-vol', vol_geom, data=0)
    proj_geom_vec = astra.functions.geom_2vec(proj_geom)
    
    alpha = ((ii-np.int(itr/2))*step+range_cen)*np.pi/180
    """
    # old geometry
    proj_geom_vec['Vectors'][:, 2] += np.sin(alpha)
    proj_geom_vec['Vectors'][:, 9] *= np.cos(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 10] *= np.cos(np.pi/2-alpha)
    """
    
    proj_geom_vec['Vectors'][:, 0] *= np.sin(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 1] *= np.sin(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 2] -= np.cos(np.pi/2-alpha)
    
    proj_geom_vec['Vectors'][:, 3] *= np.sin(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 4] *= np.sin(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 5] -= np.cos(np.pi/2-alpha)
    
    proj_geom_vec['Vectors'][:, 9] *= np.cos(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 10] *= np.cos(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 11] *= np.sin(np.pi/2-alpha)
    
    proj_id = astra.data3d.create('-sino', proj_geom_vec, imgs)

    print(f"    astra recons starts at {time.asctime()}")
    alg_cfg = astra.astra_dict('CGLS3D_CUDA')
    #alg_cfg = astra.astra_dict('BP3D_CUDA')

    alg_cfg['ProjectionDataId'] = proj_id
    alg_cfg['ReconstructionDataId'] = recon_id
    #alg_cfg['option'] = {}
    #alg_cfg['option']['MinConstraint'] = 0
    algorithm_id = astra.algorithm.create(alg_cfg)
    astra.algorithm.run(algorithm_id, 50)
    print(f"    astra recons finishes at {time.asctime()}")

    recon = astra.data3d.get(recon_id)
    #plt.figure(3)
    #plt.imshow(recon[50, :, :])

    fp_id, fp_data = astra.create_sino3d_gpu(recon, proj_geom, vol_geom)

    """
    warnings.filterwarnings('ignore')
    warnings.simplefilter('ignore')
    print(f"    registration starts at {time.asctime()}")

    n_cpu = os.cpu_count()
    with mp.Pool(n_cpu-1) as pool:
        rlt = pool.starmap(register, [(fp_data[:, ii, :], imgs[:, ii, :], method, thres, mode) for ii in np.int32(np.arange(na))])
    pool.join()
    pool.close()
    print(f"    astra recons finishes at {time.asctime()}")

    for jj in range(na):
        shift[jj] = rlt[jj][0]
        imgs[:, jj, :] = rlt[jj][1][:, :]
    del(rlt)
    gc.collect()
    """
    
    with h5py.File(ofn, 'a') as f:
        g0 = f['/proj_out_plane_corr/tilted_proj']       
        g0.create_dataset(f'tilt_{str((ii-np.int(itr/2))*step+range_cen).zfill(2)}_deg_proj', data=np.swapaxes(fp_data, 0, 1).astype(np.float32), dtype=np.float32)
        g0.create_dataset(f'tilt_{str((ii-np.int(itr/2))*step+range_cen).zfill(2)}_deg_recon', data=recon.astype(np.float32), dtype=np.float32)

    astra.algorithm.delete(algorithm_id)
    astra.data3d.delete(recon_id)
    astra.data3d.delete(proj_id)
    astra.data3d.delete(fp_id)
    astra.functions.clear()
    print(f"{ii}th iteration finishes at {time.asctime()}")

"""
recon_id = astra.data3d.create('-vol', vol_geom, data=recon[:, ::-1, :])
proj_id = astra.data3d.create('-sino', proj_geom, imgs)

print(f"    astra recons starts at {time.asctime()}")
alg_cfg = astra.astra_dict('SIRT3D_CUDA')
#alg_cfg = astra.astra_dict('BP3D_CUDA')

alg_cfg['ProjectionDataId'] = proj_id
alg_cfg['ReconstructionDataId'] = recon_id
#alg_cfg['option'] = {}
#alg_cfg['option']['MinConstraint'] = 0
algorithm_id = astra.algorithm.create(alg_cfg)
astra.algorithm.run(algorithm_id, 20)
print(f"    astra recons finishes at {time.asctime()}")

recon = astra.data3d.get(recon_id)
 
astra.data3d.delete(recon_id)
astra.data3d.delete(proj_id)
astra.data3d.delete(fp_id)
astra.functions.clear()
"""

0th iteration starts at Tue Sep  8 08:32:37 2020
    astra recons starts at Tue Sep  8 08:32:39 2020
    astra recons finishes at Tue Sep  8 08:47:33 2020
0th iteration finishes at Tue Sep  8 08:47:53 2020
1th iteration starts at Tue Sep  8 08:47:53 2020
    astra recons starts at Tue Sep  8 08:47:56 2020
    astra recons finishes at Tue Sep  8 09:02:58 2020
1th iteration finishes at Tue Sep  8 09:03:18 2020
2th iteration starts at Tue Sep  8 09:03:18 2020
    astra recons starts at Tue Sep  8 09:03:20 2020
    astra recons finishes at Tue Sep  8 09:18:25 2020
2th iteration finishes at Tue Sep  8 09:18:45 2020
3th iteration starts at Tue Sep  8 09:18:45 2020
    astra recons starts at Tue Sep  8 09:18:48 2020
    astra recons finishes at Tue Sep  8 09:33:55 2020
3th iteration finishes at Tue Sep  8 09:34:15 2020
4th iteration starts at Tue Sep  8 09:34:15 2020
    astra recons starts at Tue Sep  8 09:34:17 2020
    astra recons finishes at Tue Sep  8 09:49:25 2020
4th iteration finishe

'\nrecon_id = astra.data3d.create(\'-vol\', vol_geom, data=recon[:, ::-1, :])\nproj_id = astra.data3d.create(\'-sino\', proj_geom, imgs)\n\nprint(f"    astra recons starts at {time.asctime()}")\nalg_cfg = astra.astra_dict(\'SIRT3D_CUDA\')\n#alg_cfg = astra.astra_dict(\'BP3D_CUDA\')\n\nalg_cfg[\'ProjectionDataId\'] = proj_id\nalg_cfg[\'ReconstructionDataId\'] = recon_id\n#alg_cfg[\'option\'] = {}\n#alg_cfg[\'option\'][\'MinConstraint\'] = 0\nalgorithm_id = astra.algorithm.create(alg_cfg)\nastra.algorithm.run(algorithm_id, 20)\nprint(f"    astra recons finishes at {time.asctime()}")\n\nrecon = astra.data3d.get(recon_id)\n \nastra.data3d.delete(recon_id)\nastra.data3d.delete(proj_id)\nastra.data3d.delete(fp_id)\nastra.functions.clear()\n'

In [None]:
proj_geom_vec = astra.functions.geom_2vec(proj_geom)
proj_geom_vec['Vectors'][:,2] += np.sin(np.pi*0.05/180)

# ***12. 4-degree correction***

In [12]:
sli_s = 800
sli_e = 900
itr = 4
thres = 3e-4
method = 'sr'
mode = 'RIGID_BODY' # 'TRANSLATION', 'RIGID_BODY', 'SCALED_ROTATION', 'AFFINE', 'BILINEAR'

basename = 'fly_scan_id_48753'
tilt_ang = -1.9  # degree
dfn = Path(f'/media/xiao_usb/High_res_TXM_CT_2020/{basename}.h5')
ofn = Path(f'/media/xiao_usb/High_res_TXM_CT_2020/{basename}_tilt_{tilt_ang}_{method}_{mode}.h5')
"""

    translation
    rigid body (translation + rotation)
    scaled rotation (translation + rotation + scaling)
    affine (translation + rotation + scaling + shearing)
    bilinear (non-linear transformation; does not preserve straight lines)
    
"""

def register(ref, img, method='pcc', thres=1e-3, mode=None):
    if method == 'pcc':
        if thres is None:
            shift, _, _ = phase_cross_correlation(skf.sato(ref), skf.sato(img), upsample_factor=100)
            #print(shift)
            shifted_img = np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), shift)))
            return shift, shifted_img
        else:
            mask = (skf.sato(img)>thres)
            shift = phase_cross_correlation(skf.sato(ref), skf.sato(img), upsample_factor=100, reference_mask=mask, overlap_ratio=0.3)
            #print(shift)
            shifted_img = np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), shift)))
            return shift, shifted_img
    elif method == 'sr':
        if mode.upper() == 'TRANSLATION':
            sr = StackReg(StackReg.TRANSLATION)
        elif mode.upper() == 'RIGID_BODY':
            sr = StackReg(StackReg.RIGID_BODY)
        elif mode.upper() == 'SCALED_ROTATION':
            sr = StackReg(StackReg.SCALED_ROTATION)
        elif mode.upper() == 'AFFINE':
            sr = StackReg(StackReg.AFFINE)
        elif mode.upper() == 'BILINEAR':
            sr = StackReg(StackReg.BILINEAR)
        shift = sr.register(skf.sato(ref), skf.sato(img))
        shifted_img = sr.transform(img, tmat=shift)
        return shift, shifted_img

with h5py.File(dfn, 'r') as f:
    imgs = f['img_tomo'][:, sli_s:sli_e, :].astype(np.float32)
    bkg = f['img_bkg'][:, sli_s:sli_e, :].astype(np.float32)
    dark = f['img_dark'][:, sli_s:sli_e, :].astype(np.float32)
    theta = f['angle'][:]*np.pi/180.
with h5py.File(ofn, 'w') as f:
    if '/proj_corr' in f:
        del f['/proj_corr']
    g0 = f.create_group('/proj_corr')
    g0.create_group('shift_by_reproj') 
    g0.create_group('corrected_proj')
    g0.create_group('corrected_recon')
    g02 = g0.create_group('proj_corr_config')
    g02.create_dataset('method', data=str(method))
    g02.create_dataset('mode', data=str(mode))
    g02.create_dataset('tilt', data=tilt_ang)
    g02.create_dataset('slice', data=[sli_s, sli_e])
    if thres is None:
        g02.create_dataset('thres', data=str(thres))
    else:
        g02.create_dataset('thres', data=thres)        

imgs[:] = ((imgs-dark.mean(axis=0))/(bkg.mean(axis=0)-dark.mean(axis=0)))[:]
imgs[:] = tomopy.prep.stripe.remove_all_stripe(imgs, la_size=161, sm_size=31)[:]
imgs[:] = -np.log(imgs)[:]
imgs[:] = np.where(np.isnan(imgs), 0, imgs)[:]
imgs[:] = np.where(np.isinf(imgs), 0, imgs)[:]

imgs = np.swapaxes(imgs, 0, 1)
ny, na, nx = imgs.shape

cen = 1243
offset = nx/2. - cen 
def shift_cen(img, offset):
    return np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), [0, offset])))

n_cpu = os.cpu_count()
with mp.Pool(n_cpu-1) as pool:
    rlt = pool.starmap(shift_cen, [(imgs[:, ii, :], offset) for ii in np.int32(np.arange(na))])
pool.join()
pool.close()
for ii in range(na):
    imgs[:, ii, :] = rlt[ii][:, :]
del(rlt)
gc.collect()

if method == 'pcc':
    shift = np.zeros([imgs.shape[1], 2])
elif method == 'sr':
    shift = np.zeros([imgs.shape[1], 3, 3])

cen = 1280
#recon = tomopy.recon(np.swapaxes(imgs, 0, 1), theta, center=cen, algorithm='gridrec', filter_name = 'parzen')[:, ::-1, :]

ori_imgs = copy.deepcopy(imgs)

proj_geom = astra.creators.create_proj_geom('parallel3d', 1., 1., ny, nx, theta)
vol_geom = astra.creators.create_vol_geom(nx, nx, ny)

for ii in range(itr):
    print(f"{ii}th iteration starts at {time.asctime()}")
    recon_id = astra.data3d.create('-vol', vol_geom, data=0)
    proj_geom_vec = astra.functions.geom_2vec(proj_geom)
    
    #proj_geom_vec['Vectors'][:, 2] += np.sin(tilt_ang*np.pi/180)
    #proj_geom_vec['Vectors'][:, 9] *= np.cos(np.pi/2-tilt_ang*np.pi/180)
    #proj_geom_vec['Vectors'][:, 10] *= np.cos(np.pi/2-tilt_ang*np.pi/180)
    
    proj_geom_vec['Vectors'][:, 0] *= np.sin(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 1] *= np.sin(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 2] -= np.cos(np.pi/2-alpha)
    
    proj_geom_vec['Vectors'][:, 3] *= np.sin(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 4] *= np.sin(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 5] -= np.cos(np.pi/2-alpha)
    
    proj_geom_vec['Vectors'][:, 9] *= np.cos(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 10] *= np.cos(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 11] *= np.sin(np.pi/2-alpha)
    
    proj_id = astra.data3d.create('-sino', proj_geom_vec, imgs)

    print(f"    astra recons starts at {time.asctime()}")
    #alg_cfg = astra.astra_dict('SIRT3D_CUDA')
    alg_cfg = astra.astra_dict('CGLS3D_CUDA')
    #alg_cfg = astra.astra_dict('BP3D_CUDA')

    alg_cfg['ProjectionDataId'] = proj_id
    alg_cfg['ReconstructionDataId'] = recon_id
    #alg_cfg['option'] = {}
    #alg_cfg['option']['MinConstraint'] = 0
    algorithm_id = astra.algorithm.create(alg_cfg)
    astra.algorithm.run(algorithm_id, 50)
    print(f"    astra recons finishes at {time.asctime()}")

    recon = astra.data3d.get(recon_id)
    #plt.figure(3)
    #plt.imshow(recon[50, :, :])

    #fp_id, fp_data = astra.create_sino3d_gpu(recon, proj_geom, vol_geom)
    fp_id, fp_data = astra.create_sino3d_gpu(recon, proj_geom_vec, vol_geom)

    warnings.filterwarnings('ignore')
    warnings.simplefilter('ignore')
    print(f"    registration starts at {time.asctime()}")

    n_cpu = os.cpu_count()
    with mp.Pool(n_cpu-1) as pool:
        rlt = pool.starmap(register, [(fp_data[:, ii, :], imgs[:, ii, :], method, thres, mode) for ii in np.int32(np.arange(na))])
    pool.join()
    pool.close()
    print(f"    astra recons finishes at {time.asctime()}")

    for jj in range(na):
        shift[jj] = rlt[jj][0]
        imgs[:, jj, :] = rlt[jj][1][:, :]
    del(rlt)
    gc.collect()
    
    with h5py.File(ofn, 'a') as f:
        g0 = f['proj_corr']       
        g0['shift_by_reproj'] .create_dataset(f'iter_{str(ii).zfill(2)}', data=shift.astype(np.float32), dtype=np.float32)
        g0['corrected_proj'] .create_dataset(f'iter_{str(ii).zfill(2)}_shifted_images', data=imgs.astype(np.float32), dtype=np.float32)
        g0['corrected_recon'] .create_dataset(f'iter_{str(ii).zfill(2)}_recon', data=recon.astype(np.float32), dtype=np.float32)

    astra.algorithm.delete(algorithm_id)
    astra.data3d.delete(recon_id)
    astra.data3d.delete(proj_id)
    astra.data3d.delete(fp_id)
    astra.functions.clear()
    print(f"{ii}th iteration finishes at {time.asctime()}")
    
recon_id = astra.data3d.create('-vol', vol_geom, data=recon[:, ::-1, :])
proj_id = astra.data3d.create('-sino', proj_geom, imgs)

print(f"    astra recons starts at {time.asctime()}")
alg_cfg = astra.astra_dict('SIRT3D_CUDA')
#alg_cfg = astra.astra_dict('BP3D_CUDA')

alg_cfg['ProjectionDataId'] = proj_id
alg_cfg['ReconstructionDataId'] = recon_id
#alg_cfg['option'] = {}
#alg_cfg['option']['MinConstraint'] = 0
algorithm_id = astra.algorithm.create(alg_cfg)
astra.algorithm.run(algorithm_id, 20)
print(f"    astra recons finishes at {time.asctime()}")

recon = astra.data3d.get(recon_id)

astra.algorithm.delete(algorithm_id)
astra.data3d.delete(recon_id)
astra.data3d.delete(proj_id)
astra.data3d.delete(fp_id)
astra.functions.clear()

0th iteration starts at Wed Sep  9 07:44:32 2020
    astra recons starts at Wed Sep  9 07:44:34 2020
    astra recons finishes at Wed Sep  9 08:00:03 2020
    registration starts at Wed Sep  9 08:00:18 2020
    astra recons finishes at Wed Sep  9 08:05:43 2020
0th iteration finishes at Wed Sep  9 08:05:53 2020
1th iteration starts at Wed Sep  9 08:05:53 2020
    astra recons starts at Wed Sep  9 08:05:55 2020
    astra recons finishes at Wed Sep  9 08:21:27 2020
    registration starts at Wed Sep  9 08:21:42 2020
    astra recons finishes at Wed Sep  9 08:27:05 2020
1th iteration finishes at Wed Sep  9 08:27:16 2020
2th iteration starts at Wed Sep  9 08:27:16 2020
    astra recons starts at Wed Sep  9 08:27:18 2020
    astra recons finishes at Wed Sep  9 08:42:51 2020
    registration starts at Wed Sep  9 08:43:06 2020
    astra recons finishes at Wed Sep  9 08:48:30 2020
2th iteration finishes at Wed Sep  9 08:48:40 2020
3th iteration starts at Wed Sep  9 08:48:40 2020
    astra recon

In [None]:
print((ii-np.int(itr/2))*step+range_cen)
print(ii-np.int(itr/2))

In [None]:
print(np.sin(88*np.pi/180))

# ***13. in plane tilt correction***

In [15]:
itr = 5
range_cen = 0.5
step = 0.2 # degree
#method = 'sr'
#mode = 'RIGID_BODY' # 'TRANSLATION', 'RIGID_BODY', 'SCALED_ROTATION', 'AFFINE', 'BILINEAR'
"""

    translation
    rigid body (translation + rotation)
    scaled rotation (translation + rotation + scaling)
    affine (translation + rotation + scaling + shearing)
    bilinear (non-linear transformation; does not preserve straight lines)
    
"""

def register(ref, img, method='pcc', thres=1e-3, mode=None):
    if method == 'pcc':
        if thres is None:
            shift, _, _ = phase_cross_correlation(skf.sato(ref), skf.sato(img), upsample_factor=100)
            #print(shift)
            shifted_img = np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), shift)))
            return shift, shifted_img
        else:
            mask = (skf.sato(img)>thres)
            shift = phase_cross_correlation(skf.sato(ref), skf.sato(img), upsample_factor=100, reference_mask=mask, overlap_ratio=0.3)
            #print(shift)
            shifted_img = np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), shift)))
            return shift, shifted_img
    elif method == 'sr':
        if mode.upper() == 'TRANSLATION':
            sr = StackReg(StackReg.TRANSLATION)
        elif mode.upper() == 'RIGID_BODY':
            sr = StackReg(StackReg.RIGID_BODY)
        elif mode.upper() == 'SCALED_ROTATION':
            sr = StackReg(StackReg.SCALED_ROTATION)
        elif mode.upper() == 'AFFINE':
            sr = StackReg(StackReg.AFFINE)
        elif mode.upper() == 'BILINEAR':
            sr = StackReg(StackReg.BILINEAR)
        shift = sr.register(skf.sato(ref), skf.sato(img))
        shifted_img = sr.transform(img, tmat=shift)
        return shift, shifted_img
    
#ifn = Path('/media/xiao_usb/High_res_TXM_CT_2020/fly_scan_id_48753_sr_RIGID_BODY.h5')
#ofn = Path('/media/xiao_usb/High_res_TXM_CT_2020/fly_scan_id_48753_sr_RIGID_BODY.h5')
ifn = Path('/media/xiao_usb/High_res_TXM_CT_2020/fly_scan_id_48753.h5')
ofn = Path(f'/media/xiao_usb/High_res_TXM_CT_2020/fly_scan_id_48753_in-plane_tilt_step_{step}_cen_{range_cen}_out_plane_series.h5')

sli_s = 800
sli_e = 900

with h5py.File(ifn, 'r') as f:
    imgs = f['img_tomo'][:, sli_s:sli_e, :].astype(np.float32)
    bkg = f['img_bkg'][:, sli_s:sli_e, :].astype(np.float32)
    dark = f['img_dark'][:, sli_s:sli_e, :].astype(np.float32)
    theta = f['angle'][:]*np.pi/180.
    #method = f['/proj_corr/proj_corr_config/method'][()]
    #mode = f['/proj_corr/proj_corr_config/mode'][()] # 'TRANSLATION', 'RIGID_BODY', 'SCALED_ROTATION', 'AFFINE', 'BILINEAR'
    
with h5py.File(ofn, 'a') as f:
    if '/proj_out_plane_corr' in f:
        del f['/proj_out_plane_corr']
    g0 = f.create_group('/proj_out_plane_corr')
    g01 = g0.create_group('tilted_proj')
    g02 = g0.create_group('proj_corr_config')
    g02.create_dataset('method', data='astra')
    g02.create_dataset('mode', data='proj_vector correction')
    g02.create_dataset('step', data=step, dtype=np.float32)
    g02.create_dataset('angle_range', data=step, dtype=np.float32)
         

imgs[:] = ((imgs-dark.mean(axis=0))/(bkg.mean(axis=0)-dark.mean(axis=0)))[:]
imgs[:] = tomopy.prep.stripe.remove_all_stripe(imgs, la_size=161, sm_size=31)[:]
imgs[:] = -np.log(imgs)[:]
imgs[:] = np.where(np.isnan(imgs), 0, imgs)[:]
imgs[:] = np.where(np.isinf(imgs), 0, imgs)[:]

imgs = np.swapaxes(imgs, 0, 1)
ny, na, nx = imgs.shape

cen = 1243
offset = nx/2. - cen 
def shift_cen(img, offset):
    return np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), [0, offset])))

n_cpu = os.cpu_count()
with mp.Pool(n_cpu-1) as pool:
    rlt = pool.starmap(shift_cen, [(imgs[:, ii, :], offset) for ii in np.int32(np.arange(na))])
    
pool.join()
pool.close()

for ii in range(na):
    imgs[:, ii, :] = rlt[ii][:, :]
del(rlt)
gc.collect()

cen = 1280

ori_imgs = copy.deepcopy(imgs)

proj_geom = astra.creators.create_proj_geom('parallel3d', 1., 1., ny, nx, theta)
vol_geom = astra.creators.create_vol_geom(nx, nx, ny)

for ii in range(itr):
    print(f"{ii}th iteration starts at {time.asctime()}")
    recon_id = astra.data3d.create('-vol', vol_geom, data=0)
    proj_geom_vec = astra.functions.geom_2vec(proj_geom)
    
    alpha = ((ii-np.int(itr/2))*step+range_cen)*np.pi/180
    
    """
    proj_geom_vec['Vectors'][:, 0] *= np.sin(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 1] *= np.sin(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 2] -= np.cos(np.pi/2-alpha)
    
    proj_geom_vec['Vectors'][:, 3] *= np.sin(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 4] *= np.sin(np.pi/2-alpha)
    proj_geom_vec['Vectors'][:, 5] -= np.cos(np.pi/2-alpha)
    """
    
    proj_geom_vec['Vectors'][:, 6] *= np.cos(alpha)
    proj_geom_vec['Vectors'][:, 7] *= np.cos(alpha)
    proj_geom_vec['Vectors'][:, 8] += np.sin(alpha)
    
    proj_geom_vec['Vectors'][:, 9] = np.sin(alpha)*proj_geom_vec['Vectors'][:, 1]
    proj_geom_vec['Vectors'][:, 10] = np.sin(alpha)*proj_geom_vec['Vectors'][:, 3]
    proj_geom_vec['Vectors'][:, 11] *= np.cos(alpha)
    
    proj_id = astra.data3d.create('-sino', proj_geom_vec, imgs)

    print(f"    astra recons starts at {time.asctime()}")
    alg_cfg = astra.astra_dict('CGLS3D_CUDA')
    #alg_cfg = astra.astra_dict('BP3D_CUDA')

    alg_cfg['ProjectionDataId'] = proj_id
    alg_cfg['ReconstructionDataId'] = recon_id
    #alg_cfg['option'] = {}
    #alg_cfg['option']['MinConstraint'] = 0
    algorithm_id = astra.algorithm.create(alg_cfg)
    astra.algorithm.run(algorithm_id, 50)
    print(f"    astra recons finishes at {time.asctime()}")

    recon = astra.data3d.get(recon_id)
    #plt.figure(3)
    #plt.imshow(recon[50, :, :])

    fp_id, fp_data = astra.create_sino3d_gpu(recon, proj_geom, vol_geom)
    
    with h5py.File(ofn, 'a') as f:
        g0 = f['/proj_out_plane_corr/tilted_proj']       
        g0.create_dataset(f'tilt_{str((ii-np.int(itr/2))*step+range_cen).zfill(2)}_deg_proj', data=np.swapaxes(fp_data, 0, 1).astype(np.float32), dtype=np.float32)
        g0.create_dataset(f'tilt_{str((ii-np.int(itr/2))*step+range_cen).zfill(2)}_deg_recon', data=recon.astype(np.float32), dtype=np.float32)

    astra.algorithm.delete(algorithm_id)
    astra.data3d.delete(recon_id)
    astra.data3d.delete(proj_id)
    astra.data3d.delete(fp_id)
    astra.functions.clear()
    print(f"{ii}th iteration finishes at {time.asctime()}")



0th iteration starts at Thu Sep 10 06:41:18 2020
    astra recons starts at Thu Sep 10 06:41:20 2020
    astra recons finishes at Thu Sep 10 06:56:52 2020
0th iteration finishes at Thu Sep 10 06:57:13 2020
1th iteration starts at Thu Sep 10 06:57:13 2020
    astra recons starts at Thu Sep 10 06:57:15 2020
    astra recons finishes at Thu Sep 10 07:12:59 2020
1th iteration finishes at Thu Sep 10 07:13:19 2020
2th iteration starts at Thu Sep 10 07:13:19 2020
    astra recons starts at Thu Sep 10 07:13:22 2020
    astra recons finishes at Thu Sep 10 07:29:09 2020
2th iteration finishes at Thu Sep 10 07:29:33 2020
3th iteration starts at Thu Sep 10 07:29:33 2020
    astra recons starts at Thu Sep 10 07:29:35 2020
    astra recons finishes at Thu Sep 10 07:45:32 2020
3th iteration finishes at Thu Sep 10 07:45:56 2020
4th iteration starts at Thu Sep 10 07:45:56 2020
    astra recons starts at Thu Sep 10 07:45:58 2020
    astra recons finishes at Thu Sep 10 08:01:46 2020
4th iteration finishe