Skip to content

Commit

Permalink
Add subpixel-precision phase correlation function to feature module
Browse files Browse the repository at this point in the history
  • Loading branch information
msarahan committed Dec 15, 2014
1 parent 1344096 commit 8f82d73
Show file tree
Hide file tree
Showing 5 changed files with 413 additions and 0 deletions.
2 changes: 2 additions & 0 deletions TODO.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ Version 0.12
`skimage.transform.PolynomialTransform._params`,
`skimage.transform.PiecewiseAffineTransform.affines_*` attributes
* Remove deprecated functions `skimage.filters.denoise_*`
* Add 3D phantom in `skimage.data`
* Add 3D test case of `skimage.feature.phase_correlate`
84 changes: 84 additions & 0 deletions doc/examples/plot_phase_correlate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
=====================================
Cross-Correlation (Phase Correlation)
=====================================
In this example, we use phase correlation to identify the relative shift
between two similar-sized images.
The ``phase_correlate`` function uses cross-correlation in Fourier space,
optionally employing an upsampled matrix-multiplication DFT to achieve
arbitrary subpixel precision. [1]_
.. [1] Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup,
"Efficient subpixel image registration algorithms," Optics Letters 33,
156-158 (2008).
"""
import numpy as np
import matplotlib.pyplot as plt
import scipy.ndimage
import scipy.signal as ss

from skimage import data
from skimage.feature import phase_correlate
from skimage.feature.phase_correlate import _upsampled_dft

image = data.camera()
shift = (-2.4, 1.32)
# (-2.4, 1.32) pixel offset relative to reference coin
offset_image = scipy.ndimage.shift(image, shift)
print("Known offset (y, x):")
print(shift)

# pixel precision first
shift, error, diffphase = phase_correlate(image, offset_image)

fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(8, 3))

ax1.imshow(image)
ax1.set_axis_off()
ax1.set_title('Reference image')

ax2.imshow(offset_image)
ax2.set_axis_off()
ax2.set_title('Offset image')

# Calculate a cross-correlogram to show what the algorithm is doing behind
# the scenes
image_product = np.fft.fft2(image)*np.fft.fft2(offset_image).conj()
cc_image = np.fft.fftshift(np.fft.ifft2(image_product))
ax3.imshow(cc_image.real)
ax3.set_axis_off()
ax3.set_title("Cross-correlation")

plt.show()

print("Detected pixel offset (y, x):")
print(shift)

# subpixel precision
shift, error, diffphase = phase_correlate(image, offset_image, 100, ss.hamming)

fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(8, 3))

ax1.imshow(image)
ax1.set_axis_off()
ax1.set_title('Reference image')

ax2.imshow(offset_image)
ax2.set_axis_off()
ax2.set_title('Offset image')

# Calculate the upsampled DFT, again to show what the algorithm is doing
# behind the scenes.
cc_image = _upsampled_dft(image_product, 150, 100, (shift*100)+75).conj()
ax3.imshow(cc_image.real)
ax3.set_axis_off()
ax3.set_title("Supersampled XC sub-area")


plt.show()

print("Detected subpixel offset (y, x):")
print(shift)
2 changes: 2 additions & 0 deletions skimage/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
hessian_matrix_eigvals, hessian_matrix_det)
from .corner_cy import corner_moravec, corner_orientations
from .template import match_template
from .phase_correlate import phase_correlate
from .brief import BRIEF
from .censure import CENSURE
from .orb import ORB
Expand Down Expand Up @@ -40,6 +41,7 @@
'corner_fast',
'corner_orientations',
'match_template',
'phase_correlate',
'BRIEF',
'CENSURE',
'ORB',
Expand Down
245 changes: 245 additions & 0 deletions skimage/feature/phase_correlate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@

# -*- coding: utf-8 -*-
"""
Port of Manuel Guizar's code from:
http://www.mathworks.com/matlabcentral/fileexchange/18401-efficient-subpixel-image-registration-by-cross-correlation
"""

import numpy as np


def _upsampled_dft(data, upsampled_region_size=None,
upsample_factor=1, axis_offsets=None):
"""
Upsampled DFT by matrix multiplication.
This code is intended to provide the same result as if the following
operations were performed:
- Embed the array "data" in an array that is ``upsample_factor`` times
larger in each dimension. ifftshift to bring the center of the
image to (1,1).
- Take the FFT of the larger array.
- Extract an ``[upsampled_region_size]`` region of the result, starting
with the ``[axis_offsets+1]`` element.
It achieves this result by computing the DFT in the output array without
the need to zeropad. Much faster and memory efficient than the zero-padded
FFT approach if ``upsampled_region_size`` is much smaller than
``data.size * upsample_factor``.
Parameters
----------
data : 2D ndarray
The input data array (DFT of original data) to upsample.
upsampled_region_size : integer or tuple of integers, optional
The size of the region to be sampled. If one integer is provided, it
is duplicated up to the dimensionality of ``data``. If None, this is
equal to ``data.shape``.
upsample_factor : integer, optional
The upsampling factor. Defaults to 1.
axis_offsets : tuple of integers, optional
The offsets of the region to be sampled. Defaults to None (uses
image center)
Returns
-------
output : 2D ndarray
The upsampled DFT of the specified region.
"""
if upsampled_region_size is None:
upsampled_region_size = data.shape
# if people pass in an integer, expand it to a list of equal-sized sections
elif not hasattr(upsampled_region_size, "__iter__"):
upsampled_region_size = [upsampled_region_size, ] * data.ndim
else:
if len(upsampled_region_size) != data.ndim:
raise ValueError("shape of upsampled region sizes must be equal "
"to input data's number of dimensions.")

if axis_offsets is None:
axis_offsets = [0, ] * data.ndim
elif not hasattr(axis_offsets, "__iter__"):
axis_offsets = [axis_offsets, ] * data.ndim
else:
if len(axis_offsets) != data.ndim:
raise ValueError("number of axis offsets must be equal to input "
"data's number of dimensions.")

col_kernel = np.exp(
(-1j * 2 * np.pi / (data.shape[1] * upsample_factor)) *
(np.fft.ifftshift(np.arange(data.shape[1]))[:, None] -
np.floor(data.shape[1] / 2)).dot(
np.arange(upsampled_region_size[1])[None, :] - axis_offsets[1])
)
row_kernel = np.exp(
(-1j * 2 * np.pi / (data.shape[0] * upsample_factor)) *
(np.arange(upsampled_region_size[0])[:, None] - axis_offsets[0]).dot(
np.fft.ifftshift(np.arange(data.shape[0]))[None, :] -
np.floor(data.shape[0] / 2))
)

return row_kernel.dot(data).dot(col_kernel)


def _nd_window(data, filter_function):
"""
Performs an in-place windowing on N-dimensional spatial-domain data.
This is done to mitigate boundary effects in the FFT.
Parameters
----------
data : ndarray
Input data to be windowed, modified in place.
filter_function : 1D window generation function
Function should accept one argument: the window length.
Example: scipy.signal.hamming
"""
for axis, axis_size in enumerate(data.shape):
# set up shape for numpy broadcasting
filter_shape = [1, ] * data.ndim
filter_shape[axis] = axis_size
window = filter_function(axis_size).reshape(filter_shape)
# scale the window intensities to maintain image intensity
np.power(window, (1.0 / data.ndim), output=window)
data *= window


def phase_correlate(src_image, target_image, upsample_factor=1,
filter_function=None):
"""
Efficient subpixel image registration by cross-correlation.
This code gives the same precision as the FFT upsampled cross-correlation
in a fraction of the computation time and with reduced memory requirements.
It obtains an initial estimate of the cross-correlation peak by an FFT and
then refines the shift estimation by upsampling the DFT only in a small
neighborhood of that estimate by means of a matrix-multiply DFT.
Parameters
----------
src_image : ndarray
Reference image. Real-only data treated as spatial domain, complex
data treated as frequency domain.
target_image : ndarray
Image to register. Must be same domain and same dimensionality as
``src_image``.
upsample_factor : int, optional
Upsampling factor. Images will be registered to within
``1 / upsample_factor`` of a pixel. For example
``upsample_factor == 20`` means the images will be registered
within 1/20th of a pixel. Default is 1 (no upsampling)
filter_function : 1D window generation function, optional
Window input data to mitigate boundary effects in the FFT.
Default is ``None`` (no filter). Example: ``scipy.signal.hamming``
Returns
-------
shifts : ndarray
Shift vector (in pixels) required to register ``target_image`` with
``src_image``. Axis ordering is consistent with numpy (e.g. Z, Y, X)
error : float
Translation invariant normalized RMS error between ``src_image`` and
``target_image``.
phasediff : float
Global phase difference between the two images (should be
zero if images are non-negative).
References
----------
.. [1] Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup,
"Efficient subpixel image registration algorithms," Optics Letters 33,
156-158 (2008).
"""
# images must be the same shape
if src_image.shape != target_image.shape:
raise ValueError("Error: images must be same size for phase_correlate")

# only 2D data makes sense right now
if src_image.ndim != 2 and upsample_factor > 1:
raise NotImplementedError("Error: phase_correlate only supports "
"subpixel registration for 2D images")

# images must be both real, or both complex (fft data input)
if np.iscomplex(src_image).any() != np.iscomplex(target_image).any():
raise ValueError("Error: input images must be both real, or both "
"complex.")

# assume complex data is already in Fourier space
if np.iscomplex(src_image).any():
src_image_freq = src_image
target_image_freq = target_image
# real data needs to be fft'd.
else:
src_image = np.array(src_image, dtype=np.float64, copy=False)
target_image = np.array(target_image, dtype=np.float64, copy=False)
if filter_function:
# apply window function over each dimension using broadcasting
_nd_window(src_image, filter_function)
src_image_freq = np.fft.fftn(src_image)
target_image_freq = np.fft.fftn(target_image)

# Whole-pixel shift - Compute cross-correlation by an IFFT
shape = src_image_freq.shape
image_product = src_image_freq * target_image_freq.conj()
cross_correlation = np.fft.fftshift(np.fft.ifftn(image_product))

# Locate maximum
maxima = np.unravel_index(np.argmax(cross_correlation),
cross_correlation.shape)
midpoints = np.array([np.fix(axis_size / 2) for axis_size in shape])

shifts = np.array(maxima, dtype=np.float64)
shifts -= midpoints

if upsample_factor == 1:
rf0 = np.sum(np.abs(src_image_freq) ** 2) / (src_image_freq.size)
rg0 = np.sum(np.abs(target_image_freq) ** 2) / (target_image_freq.size)
CCmax = cross_correlation.max()
error = 1.0 - CCmax * CCmax.conj() / (rg0 * rf0)
error = np.sqrt(np.abs(error))
phasediff = np.arctan2(CCmax.imag, CCmax.real)
return shifts, error, phasediff
# If upsampling > 1, then refine estimate with matrix multiply DFT
else:
# Initial shift estimate in upsampled grid
shifts = np.round(shifts*upsample_factor) / upsample_factor
upsampled_region_size = np.ceil(upsample_factor * 1.5)
# Center of output array at dftshift + 1
dftshift = np.fix(upsampled_region_size / 2.0)
midpoint_product = np.product(midpoints)
normalization = (midpoint_product * upsample_factor ** 2)
# Matrix multiply DFT around the current shift estimate
sample_region_offset = shifts*upsample_factor + dftshift
cross_correlation = _upsampled_dft(image_product,
upsampled_region_size,
upsample_factor,
sample_region_offset).conj()
cross_correlation /= normalization
# Locate maximum and map back to original pixel grid
maxima = np.array(np.unravel_index(np.argmax(cross_correlation),
cross_correlation.shape),
dtype=np.float64)
maxima -= dftshift
shifts = shifts - maxima / upsample_factor
CCmax = cross_correlation.max()
rg00 = _upsampled_dft(src_image_freq * src_image_freq.conj(), 1,
upsample_factor)
rg00 /= normalization
rf00 = _upsampled_dft(target_image_freq * target_image_freq.conj(), 1,
upsample_factor)
rf00 /= normalization
error = 1.0 - CCmax * CCmax.conj() / (rg00 * rf00)
error = np.sqrt(np.abs(error))[0, 0]
phasediff = np.arctan2(CCmax.imag, CCmax.real)

# If its only one row or column the shift along that dimension has no
# effect. We set to zero.
for dim in range(src_image_freq.ndim):
if midpoints[dim] == 1:
shifts[dim] = 0

# the result is the shift necessary to apply to ``target_image`` to bring
# it into registration with ``src_image``. It will be opposite in sign to
# any shift that was applied to ``src_image`` to create ``target_image``.
return shifts, error, phasediff

0 comments on commit 8f82d73

Please sign in to comment.