Skip to content

Commit

Permalink
ENH: Modularized uses of fft/ifft/fft2/ifft2 into tomopy.util.misc
Browse files Browse the repository at this point in the history
Provided implementations of these function using mkl_fft or pyfftw or numpy.fft

Implemented support for in-place FFT operation.

Optimized constructon of _reciprocal_grid and _reciprocal_ccord
  • Loading branch information
oleksandr-pavlyk committed Oct 27, 2017
1 parent 01e3396 commit 4f492a1
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 27 deletions.
33 changes: 15 additions & 18 deletions tomopy/prep/phase.py
Expand Up @@ -54,7 +54,9 @@
unicode_literals)

import numpy as np
import pyfftw
from tomopy.util.misc import (fft2, ifft2)


import tomopy.util.mproc as mproc
import logging

Expand Down Expand Up @@ -120,7 +122,7 @@ def retrieve_phase(
phase_filter = np.fft.fftshift(
_paganin_filter_factor(energy, dist, alpha, w2))

prj = val * np.ones((dy + 2 * py, dz + 2 * pz), dtype='float32')
prj = np.full((dy + 2 * py, dz + 2 * pz), val, dtype='float32')
arr = mproc.distribute_jobs(
tomo,
func=_retrieve_phase,
Expand All @@ -134,29 +136,21 @@ def retrieve_phase(
def _retrieve_phase(tomo, phase_filter, px, py, prj, pad):
dx, dy, dz = tomo.shape
num_jobs = tomo.shape[0]
normalized_phase_filter = phase_filter / phase_filter.max()
for m in range(num_jobs):
prj[px:dy + px, py:dz + py] = tomo[m]
prj[:px] = prj[px]
prj[-px:] = prj[-px-1]
prj[:, :py] = prj[:, py][:, np.newaxis]
prj[:, -py:] = prj[:, -py-1][:, np.newaxis]
fproj = pyfftw.interfaces.numpy_fft.fft2(
prj, planner_effort=_plan_effort(num_jobs))
filtproj = np.multiply(phase_filter, fproj)
proj = np.real(pyfftw.interfaces.numpy_fft.ifft2(
filtproj, planner_effort=_plan_effort(num_jobs))
) / phase_filter.max()
fproj = fft2(prj, extra_info=num_jobs)
fproj *= normalized_phase_filter
proj = np.real(ifft2(fproj, extra_info=num_jobs, overwrite_input=True))
if pad:
proj = proj[px:dy + px, py:dz + py]
tomo[m] = proj


def _plan_effort(num_jobs):
if num_jobs > 10:
return 'FFTW_MEASURE'
else:
return 'FFTW_ESTIMATE'


def _calc_pad(tomo, pixel_size, dist, energy, pad):
"""
Expand Down Expand Up @@ -226,8 +220,9 @@ def _reciprocal_grid(pixel_size, nx, ny):
# Sampling in reciprocal space.
indx = _reciprocal_coord(pixel_size, nx)
indy = _reciprocal_coord(pixel_size, ny)
du, dv = np.meshgrid(indy, indx)
return np.square(du) + np.square(dv)
np.square(indx, out=indx)
np.square(indy, out=indy)
return np.add.outer(indx, indy)


def _reciprocal_coord(pixel_size, num_grid):
Expand All @@ -247,5 +242,7 @@ def _reciprocal_coord(pixel_size, num_grid):
ndarray
Grid coordinates.
"""
return (1 / ((num_grid - 1) * pixel_size)) * \
np.arange(-(num_grid - 1) * 0.5, num_grid * 0.5)
n = num_grid - 1
rc = np.arange(-n, num_grid, 2, dtype = np.float32)
rc *= 0.5 / (n * pixel_size)
return rc
13 changes: 5 additions & 8 deletions tomopy/prep/stripe.py
Expand Up @@ -55,11 +55,11 @@

import numpy as np
import pywt
import pyfftw
import tomopy.prep.phase as phase
import tomopy.util.extern as extern
import tomopy.util.mproc as mproc
import tomopy.util.dtype as dtype
from tomopy.util.misc import (fft, ifft)
import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -143,19 +143,16 @@ def _remove_stripe_fw(tomo, level, wname, sigma, pad):
# FFT transform of horizontal frequency bands.
for n in range(level):
# FFT
fcV = np.fft.fftshift(pyfftw.interfaces.numpy_fft.fft(
cV[n], axis=0, planner_effort=phase._plan_effort(num_jobs)))
fcV = np.fft.fftshift(fft(cV[n], axis=0, extra_info=num_jobs))
my, mx = fcV.shape

# Damping of ring artifact information.
y_hat = (np.arange(-my, my, 2, dtype='float32') + 1) / 2
damp = 1 - np.exp(-np.power(y_hat, 2) / (2 * np.power(sigma, 2)))
fcV = np.multiply(fcV, np.transpose(np.tile(damp, (mx, 1))))
damp = -np.expm1(-np.square(y_hat) / (2 * np.square(sigma)))
fcV *= np.transpose(np.tile(damp, (mx, 1)))

# Inverse FFT.
cV[n] = np.real(pyfftw.interfaces.numpy_fft.ifft(
np.fft.ifftshift(fcV), axis=0,
planner_effort=phase._plan_effort(num_jobs)))
cV[n] = np.real(ifft(np.fft.ifftshift(fcV), axis=0, extra_info=num_jobs))

# Wavelet reconstruction.
for n in range(level)[::-1]:
Expand Down
84 changes: 83 additions & 1 deletion tomopy/util/misc.py
Expand Up @@ -56,13 +56,15 @@
import logging
import warnings



logger = logging.getLogger(__name__)


__author__ = "Doga Gursoy"
__copyright__ = "Copyright (c) 2016, UChicago Argonne, LLC."
__docformat__ = 'restructuredtext en'
__all__ = ['deprecated']
__all__ = ['deprecated', 'fft2', 'ifft2', 'fft', 'ifft']


def deprecated(func, msg=None):
Expand All @@ -81,3 +83,83 @@ def new_func(*args, **kwargs):
new_func.__doc__ = func.__doc__
new_func.__dict__.update(func.__dict__)
return new_func


try:
import mkl_fft
fft_impl = 'mkl_fft'
logger.debug('FFT implementation is mkl_fft')
except ImportError:
try:
import pyfftw
fft_impl = 'pyfftw'
logger.debug('FFT implementation is pyfftw')
except ImportError:
import np.fft
fft_impl = 'numpy.fft'
logger.debug('FFT implementation is numpy.fft')


if fft_impl == 'mkl_fft':
def fft(x, n=None, axis=-1, overwrite_input=False, extra_info=None):
return mkl_fft.fft(x, n=n, axis=axis, overwrite_x=overwrite_input)


def ifft(x, n=None, axis=-1, overwrite_input=False, extra_info=None):
return mkl_fft.ifft(x, n=n, axis=axis, overwrite_x=overwrite_input)


def fft2(x, s=None, axes=(-2,-1), overwrite_input=False, extra_info=None):
return mkl_fft.fft2(x, shape=s, axes=axes, overwrite_x=overwrite_input)


def ifft2(x, s=None, axes=(-2,-1), overwrite_input=False, extra_info=None):
return mkl_fft.ifft2(x, shape=s, axes=axes, overwrite_x=overwrite_input)

elif fft_impl == 'pyfftw':
def _plan_effort(num_jobs):
if not num_jobs:
return 'FFTW_MEASURE'
if num_jobs > 10:
return 'FFTW_MEASURE'
else:
return 'FFTW_ESTIMATE'

def fft(x, n=None, axis=-1, overwrite_input=False, extra_info=None):
return pyfftw.interfaces.numpy_fft.fft(x, n=n, axis=axis,
overwrite_input=overwrite_input,
planner_effort=_plan_effort(extra_info))


def ifft(x, n=None, axis=-1, overwrite_input=False, extra_info=None):
return pyfftw.interfaces.numpy_fft.ifft(x, n=n, axis=axis,
overwrite_input=overwrite_input,
planner_effort=_plan_effort(extra_info))


def fft2(x, s=None, axes=(-2,-1), overwrite_input=False, extra_info=None):
return pyfftw.interfaces.numpy_fft.fft2(x, s=s, axes=axes,
overwrite_input=overwrite_input,
planner_effort=_plan_effort(extra_info))


def ifft2(x, s=None, axes=(-2,-1), overwrite_input=False, extra_info=None):
return pyfftw.interfaces.numpy_fft.ifft2(x, s=s, axes=axes,
overwrite_input=overwrite_input,
planner_effort=_plan_effort(extra_info))

else:
def fft(x, n=None, axis=-1, overwrite_input=False, extra_info=None):
return np.fft.fft(x, n=n, axis=axis)


def ifft(x, n=None, axis=-1, overwrite_input=False, extra_info=None):
return np.fft.ifft(x, n=n, axis=axis)


def fft2(x, s=None, axes=(-2,-1), overwrite_input=False, extra_info=None):
return np.fft.fft2(x, s=s, axes=axes)


def ifft2(x, s=None, axes=(-2,-1), overwrite_input=False, extra_info=None):
return np.fft.ifft2(x, s=s, axes=axes)

0 comments on commit 4f492a1

Please sign in to comment.