Skip to content

Commit

Permalink
Merge pull request #523 from nghia-vo/master
Browse files Browse the repository at this point in the history
Enable parallel computing for find_center_vo, allow 2D sinogram input…
  • Loading branch information
carterbox committed Oct 15, 2020
2 parents 6b053f3 + f6096b5 commit 78ecc9d
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 62 deletions.
130 changes: 75 additions & 55 deletions source/tomopy/recon/rotation.py
Expand Up @@ -53,17 +53,20 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)

import logging
import os.path

import numpy as np
from scipy import ndimage
from tomopy.util.misc import fft2, write_tiff
from scipy.optimize import minimize
from skimage.registration import phase_cross_correlation

from tomopy.misc.corr import circ_mask
from tomopy.misc.morph import downsample
from tomopy.recon.algorithm import recon
import tomopy.util.dtype as dtype
import os.path
import logging
from tomopy.util.misc import fft2, write_tiff
from tomopy.util.mproc import distribute_jobs

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -203,18 +206,19 @@ def _find_center_cost(


def find_center_vo(tomo, ind=None, smin=-50, smax=50, srad=6, step=0.25,
ratio=0.5, drop=20):
ratio=0.5, drop=20, ncore=None):
"""
Find rotation axis location using Nghia Vo's method. :cite:`Vo:14`.
Parameters
----------
tomo : ndarray
3D tomographic data.
3D tomographic data or a 2D sinogram.
ind : int, optional
Index of the slice to be used for reconstruction.
smin, smax : int, optional
Coarse search radius. Reference to the horizontal center of the sinogram.
Coarse search radius. Reference to the horizontal center of
the sinogram.
srad : float, optional
Fine search radius.
step : float, optional
Expand All @@ -224,13 +228,18 @@ def find_center_vo(tomo, ind=None, smin=-50, smax=50, srad=6, step=0.25,
It's used to generate the mask.
drop : int, optional
Drop lines around vertical center of the mask.
ncore : int, optional
Number of cores that will be assigned to jobs.
Returns
-------
float
Rotation axis location.
"""
tomo = dtype.as_float32(tomo)
if tomo.ndim == 2:
tomo = np.expand_dims(tomo, 1)
ind = 0
(depth, height, width) = tomo.shape
if ind is None:
ind = height // 2
Expand All @@ -245,66 +254,86 @@ def find_center_vo(tomo, ind=None, smin=-50, smax=50, srad=6, step=0.25,
# Denoising
# There's a critical reason to use different window sizes
# between coarse and fine search.
_tomo_cs = ndimage.filters.gaussian_filter(_tomo, (3, 1))
_tomo_fs = ndimage.filters.gaussian_filter(_tomo, (2, 2))
_tomo_cs = ndimage.filters.gaussian_filter(_tomo, (3, 1), mode='reflect')
_tomo_fs = ndimage.filters.gaussian_filter(_tomo, (2, 2), mode='reflect')

# Coarse and fine searches for finding the rotation center.
if _tomo.shape[0] * _tomo.shape[1] > 4e6: # If data is large (>2kx2k)
_tomo_coarse = downsample(
np.expand_dims(_tomo_cs, 1), level=2)[:, 0, :]
init_cen = _search_coarse(
_tomo_coarse, smin / 4.0, smax / 4.0, ratio, drop)
_tomo_coarse, smin / 4.0, smax / 4.0, ratio, drop, ncore=ncore)
fine_cen = _search_fine(_tomo_fs, srad, step,
init_cen * 4, ratio, drop)
init_cen * 4.0, ratio, drop, ncore=ncore)
else:
init_cen = _search_coarse(_tomo_cs, smin, smax, ratio, drop)
fine_cen = _search_fine(_tomo_fs, srad, step, init_cen, ratio, drop)

init_cen = _search_coarse(_tomo_cs, smin, smax, ratio, drop,
ncore=ncore)
fine_cen = _search_fine(_tomo_fs, srad, step,
init_cen, ratio, drop, ncore=ncore)
logger.debug('Rotation center search finished: %i', fine_cen)
return fine_cen


def _search_coarse(sino, smin, smax, ratio, drop):
def _calculate_metric(shift_col, sino1, sino2, sino3, mask):
"""
Metric calculation.
"""
shift_col = 1.0 * np.squeeze(shift_col)
if np.abs(shift_col - np.floor(shift_col)) == 0.0:
shift_col = int(shift_col)
sino_shift = np.roll(sino2, shift_col, axis=1)
if shift_col >= 0:
sino_shift[:, :shift_col] = sino3[:, :shift_col]
else:
sino_shift[:, shift_col:] = sino3[:, shift_col:]
mat = np.vstack((sino1, sino_shift))
else:
sino_shift = ndimage.interpolation.shift(
sino2, (0, shift_col), order=3, prefilter=True)
if shift_col >= 0:
shift_int = int(np.ceil(shift_col))
sino_shift[:, :shift_int] = sino3[:, :shift_int]
else:
shift_int = int(np.floor(shift_col))
sino_shift[:, shift_int:] = sino3[:, shift_int:]
mat = np.vstack((sino1, sino_shift))
metric = np.mean(
np.abs(np.fft.fftshift(fft2(mat))) * mask)
return np.asarray([metric], dtype=np.float32)


def _search_coarse(sino, smin, smax, ratio, drop, ncore=None):
"""
Coarse search for finding the rotation center.
"""
(nrow, ncol) = sino.shape
cen_fliplr = (ncol - 1.0) / 2.0
smin = np.int16(np.clip(smin + cen_fliplr, 0, ncol - 1) - cen_fliplr)
smax = np.int16(np.clip(smax + cen_fliplr, 0, ncol - 1) - cen_fliplr)
# Flip left-right the [0:Pi ] sinogram to make a full [0;2Pi] sinogram
start_cor = ncol // 2 + smin
stop_cor = ncol // 2 + smax
flip_sino = np.fliplr(sino)
# Below image is used for compensating the shift of the [Pi;2Pi] sinogram
# It helps to avoid local minima.
comp_sino = np.flipud(sino)
list_shift = np.arange(smin, smax + 1)
list_metric = np.zeros(len(list_shift), dtype='float32')
comp_sino = np.flipud(sino) # Used to avoid local minima
list_cor = np.arange(start_cor, stop_cor + 0.5, 0.5)
list_metric = np.zeros(len(list_cor), dtype=np.float32)
mask = _create_mask(2 * nrow, ncol, 0.5 * ratio * ncol, drop)
sino_sino = np.vstack((sino, flip_sino))
abs_fft2_sino = np.empty_like(sino_sino)
for i in list_shift:
_sino = sino_sino[nrow:]
_sino[...] = np.roll(flip_sino, i, axis=1)
if i >= 0:
_sino[:, 0:i] = comp_sino[:, 0:i]
else:
_sino[:, i:] = comp_sino[:, i:]
fft2sino = np.fft.fftshift(fft2(sino_sino))
np.abs(fft2sino, out=abs_fft2_sino)
abs_fft2_sino *= mask
list_metric[i - smin] = abs_fft2_sino.mean()
list_shift = 2.0 * (list_cor - cen_fliplr)
list_metric = distribute_jobs(np.float32(list_shift),
_calculate_metric, axis=0,
args=(sino, flip_sino, comp_sino, mask),
ncore=ncore, nchunk=1)
minpos = np.argmin(list_metric)
if minpos == 0:
logger.debug('WARNING!!!Global minimum is out of searching range')
logger.debug('Please extend smin: %i', smin)
if minpos == len(list_metric) - 1:
logger.debug('WARNING!!!Global minimum is out of searching range')
logger.debug('Please extend smax: %i', smax)
init_cen = cen_fliplr + list_shift[minpos] / 2.0
return init_cen
cor = list_cor[minpos]
return cor


def _search_fine(sino, srad, step, init_cen, ratio, drop):
def _search_fine(sino, srad, step, init_cen, ratio, drop, ncore=None):
"""
Fine search for finding the rotation center.
"""
Expand All @@ -315,22 +344,13 @@ def _search_fine(sino, srad, step, init_cen, ratio, drop):
init_cen = np.clip(init_cen, srad, ncol - srad - 1)
list_cor = init_cen + np.arange(-srad, srad + step, step)
flip_sino = np.fliplr(sino)
comp_sino = np.flipud(sino) # Used to avoid local minima
list_metric = np.zeros(len(list_cor), dtype=np.float32)
comp_sino = np.flipud(sino)
mask = _create_mask(2 * nrow, ncol, 0.5 * ratio * ncol, drop)
for i, cor in enumerate(list_cor):
shift = 2.0 * (cor - cen_fliplr)
sino_shift = ndimage.interpolation.shift(
flip_sino, (0, shift), order=3, prefilter=True)
if shift >= 0:
shift_int = np.int16(np.ceil(shift))
sino_shift[:, :shift_int] = comp_sino[:, :shift_int]
else:
shift_int = np.int16(np.floor(shift))
sino_shift[:, shift_int:] = comp_sino[:, shift_int:]
sinojoin = np.vstack((sino, sino_shift))
list_metric[i] = np.mean(np.abs(
np.fft.fftshift(fft2(sinojoin))) * mask)
list_shift = 2.0 * (list_cor - cen_fliplr)
list_metric = distribute_jobs(np.float32(list_shift),
_calculate_metric, axis=0,
args=(sino, flip_sino, comp_sino, mask),
ncore=ncore, nchunk=1)
cor = list_cor[np.argmin(list_metric)]
return cor

Expand Down Expand Up @@ -359,10 +379,10 @@ def _create_mask(nrow, ncol, radius, drop):
dv = (nrow - 1.0) / (nrow * 2.0 * np.pi)
cen_row = np.int16(np.ceil(nrow / 2.0) - 1)
cen_col = np.int16(np.ceil(ncol / 2.0) - 1)
drop = min(drop, np.int16(np.ceil(0.1 * nrow)))
drop = min(drop, np.int16(np.ceil(0.05 * nrow)))
mask = np.zeros((nrow, ncol), dtype='float32')
for i in range(nrow):
pos = np.int16(np.round(((i - cen_row) * dv / radius) / du))
pos = np.int16(np.ceil(((i - cen_row) * dv / radius) / du))
(pos1, pos2) = np.clip(np.sort(
(-pos + cen_col, pos + cen_col)), 0, ncol - 1)
mask[i, pos1:pos2 + 1] = 1.0
Expand Down Expand Up @@ -487,8 +507,8 @@ def write_center(
'grad'
Gradient descent method with a constant step size
'tikh'
Tikhonov regularization with identity Tikhonov matrix.
Tikhonov regularization with identity Tikhonov matrix.
filter_name : str, optional
Name of the filter for analytic reconstruction.
Expand Down
14 changes: 7 additions & 7 deletions test/test_tomopy/test_recon/test_rotation.py
Expand Up @@ -56,6 +56,7 @@
#from tomopy.util.mproc import get_rank, get_nproc, barrier
import numpy as np
from scipy.ndimage.interpolation import shift as image_shift
from scipy.ndimage import zoom
import os.path
import shutil
from numpy.testing import assert_array_equal as assert_equals
Expand Down Expand Up @@ -103,16 +104,15 @@ def test_find_center(self):

def test_find_center_vo(self):
sim = read_file('sinogram.npy')
cen = find_center_vo(sim)
assert_allclose(cen, 45.28, rtol=0.015)
cen = find_center_vo(sim, smin=-10, smax=10)
assert_allclose(cen, 44.75, rtol=0.25)

def test_find_center_vo_with_downsampling(self):
sim = read_file('sinogram.npy')
np.pad(
sim, ((1000, 1000), (0, 0), (1000, 1000)),
mode="constant", constant_values=0)
cen = find_center_vo(sim)
assert_allclose(cen, 45.28, rtol=0.015)
sim = zoom(sim[:, 0, :], (45, 22), order=3, mode='reflect')
sim = np.expand_dims(sim, 1)
cen = find_center_vo(sim, smin=-10, smax=10)
assert_allclose(cen, 1002.0, rtol=0.25)

def test_find_center_pc(self):
proj_0 = read_file('projection.npy')
Expand Down

0 comments on commit 78ecc9d

Please sign in to comment.