Skip to content

Commit

Permalink
Add subpixel-precision image translation registration function to fea…
Browse files Browse the repository at this point in the history
…ture module
  • Loading branch information
msarahan committed Jan 28, 2015
1 parent 1344096 commit 9d97cca
Show file tree
Hide file tree
Showing 5 changed files with 425 additions and 0 deletions.
2 changes: 2 additions & 0 deletions TODO.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ Version 0.12
`skimage.transform.PolynomialTransform._params`,
`skimage.transform.PiecewiseAffineTransform.affines_*` attributes
* Remove deprecated functions `skimage.filters.denoise_*`
* Add 3D phantom in `skimage.data`
* Add 3D test case of `skimage.feature.phase_correlate`
83 changes: 83 additions & 0 deletions doc/examples/plot_register_translation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
=====================================
Cross-Correlation (Phase Correlation)
=====================================
In this example, we use phase correlation to identify the relative shift
between two similar-sized images.
The ``register_translation`` function uses cross-correlation in Fourier space,
optionally employing an upsampled matrix-multiplication DFT to achieve
arbitrary subpixel precision. [1]_
.. [1] Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup,
"Efficient subpixel image registration algorithms," Optics Letters 33,
156-158 (2008).
"""
import numpy as np
import matplotlib.pyplot as plt

from skimage import data
from skimage.feature import register_translation
from skimage.feature.register_translation import _upsampled_dft, fourier_shift

image = data.camera()
shift = (-2.4, 1.32)
# (-2.4, 1.32) pixel offset relative to reference coin
offset_image = fourier_shift(image, shift)
print("Known offset (y, x):")
print(shift)

# pixel precision first
shift, error, diffphase = register_translation(image, offset_image)

fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(8, 3))

ax1.imshow(image)
ax1.set_axis_off()
ax1.set_title('Reference image')

ax2.imshow(offset_image.real)
ax2.set_axis_off()
ax2.set_title('Offset image')

# View the output of a cross-correlation to show what the algorithm is
# doing behind the scenes
image_product = np.fft.fft2(image) * np.fft.fft2(offset_image).conj()
cc_image = np.fft.fftshift(np.fft.ifft2(image_product))
ax3.imshow(cc_image.real)
ax3.set_axis_off()
ax3.set_title("Cross-correlation")

plt.show()

print("Detected pixel offset (y, x):")
print(shift)

# subpixel precision
shift, error, diffphase = register_translation(image, offset_image, 100)

fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(8, 3))

ax1.imshow(image)
ax1.set_axis_off()
ax1.set_title('Reference image')

ax2.imshow(offset_image.real)
ax2.set_axis_off()
ax2.set_title('Offset image')

# Calculate the upsampled DFT, again to show what the algorithm is doing
# behind the scenes. Constants correspond to calculated values in routine.
# See source code for details.
cc_image = _upsampled_dft(image_product, 150, 100, (shift*100)+75).conj()
ax3.imshow(cc_image.real)
ax3.set_axis_off()
ax3.set_title("Supersampled XC sub-area")


plt.show()

print("Detected subpixel offset (y, x):")
print(shift)
2 changes: 2 additions & 0 deletions skimage/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
hessian_matrix_eigvals, hessian_matrix_det)
from .corner_cy import corner_moravec, corner_orientations
from .template import match_template
from .register_translation import register_translation
from .brief import BRIEF
from .censure import CENSURE
from .orb import ORB
Expand Down Expand Up @@ -40,6 +41,7 @@
'corner_fast',
'corner_orientations',
'match_template',
'register_translation',
'BRIEF',
'CENSURE',
'ORB',
Expand Down
275 changes: 275 additions & 0 deletions skimage/feature/register_translation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
# -*- coding: utf-8 -*- """
"""
Port of Manuel Guizar's code from:
http://www.mathworks.com/matlabcentral/fileexchange/18401-efficient-subpixel-image-registration-by-cross-correlation
"""

import numpy as np


def _upsampled_dft(data, upsampled_region_size=None,
upsample_factor=1, axis_offsets=None):
"""
Upsampled DFT by matrix multiplication.
This code is intended to provide the same result as if the following
operations were performed:
- Embed the array "data" in an array that is ``upsample_factor`` times
larger in each dimension. ifftshift to bring the center of the
image to (1,1).
- Take the FFT of the larger array.
- Extract an ``[upsampled_region_size]`` region of the result, starting
with the ``[axis_offsets+1]`` element.
It achieves this result by computing the DFT in the output array without
the need to zeropad. Much faster and memory efficient than the zero-padded
FFT approach if ``upsampled_region_size`` is much smaller than
``data.size * upsample_factor``.
Parameters
----------
data : 2D ndarray
The input data array (DFT of original data) to upsample.
upsampled_region_size : integer or tuple of integers, optional
The size of the region to be sampled. If one integer is provided, it
is duplicated up to the dimensionality of ``data``. If None, this is
equal to ``data.shape``.
upsample_factor : integer, optional
The upsampling factor. Defaults to 1.
axis_offsets : tuple of integers, optional
The offsets of the region to be sampled. Defaults to None (uses
image center)
Returns
-------
output : 2D ndarray
The upsampled DFT of the specified region.
"""
if upsampled_region_size is None:
upsampled_region_size = data.shape
# if people pass in an integer, expand it to a list of equal-sized sections
elif not hasattr(upsampled_region_size, "__iter__"):
upsampled_region_size = [upsampled_region_size, ] * data.ndim
else:
if len(upsampled_region_size) != data.ndim:
raise ValueError("shape of upsampled region sizes must be equal "
"to input data's number of dimensions.")

if axis_offsets is None:
axis_offsets = [0, ] * data.ndim
elif not hasattr(axis_offsets, "__iter__"):
axis_offsets = [axis_offsets, ] * data.ndim
else:
if len(axis_offsets) != data.ndim:
raise ValueError("number of axis offsets must be equal to input "
"data's number of dimensions.")

col_kernel = np.exp(
(-1j * 2 * np.pi / (data.shape[1] * upsample_factor)) *
(np.fft.ifftshift(np.arange(data.shape[1]))[:, None] -
np.floor(data.shape[1] / 2)).dot(
np.arange(upsampled_region_size[1])[None, :] - axis_offsets[1])
)
row_kernel = np.exp(
(-1j * 2 * np.pi / (data.shape[0] * upsample_factor)) *
(np.arange(upsampled_region_size[0])[:, None] - axis_offsets[0]).dot(
np.fft.ifftshift(np.arange(data.shape[0]))[None, :] -
np.floor(data.shape[0] / 2))
)

return row_kernel.dot(data).dot(col_kernel)


def _compute_phasediff(cross_correlation_max):
"""
Compute global phase difference between the two images (should be
zero if images are non-negative).
Parameters
----------
cross_correlation_max : complex
The complex value of the cross correlation at its maximum point.
"""
return np.arctan2(cross_correlation_max.imag, cross_correlation_max.real)


def _compute_error(cross_correlation_max, src_amp, target_amp):
"""
Compute RMS error metric between ``src_image`` and ``target_image``.
Parameters
----------
cross_correlation_max : complex
The complex value of the cross correlation at its maximum point.
src_amp : float
The normalized average image intensity of the source image
target_amp : float
The normalized average image intensity of the target image
"""
error = 1.0 - cross_correlation_max * cross_correlation_max.conj() /\
(src_amp * target_amp)
return np.sqrt(np.abs(error))


def register_translation(src_image, target_image, upsample_factor=1,
space="real"):
"""
Efficient subpixel image translation registration by cross-correlation.
This code gives the same precision as the FFT upsampled cross-correlation
in a fraction of the computation time and with reduced memory requirements.
It obtains an initial estimate of the cross-correlation peak by an FFT and
then refines the shift estimation by upsampling the DFT only in a small
neighborhood of that estimate by means of a matrix-multiply DFT.
Parameters
----------
src_image : ndarray
Reference image.
target_image : ndarray
Image to register. Must be same dimensionality as ``src_image``.
upsample_factor : int, optional
Upsampling factor. Images will be registered to within
``1 / upsample_factor`` of a pixel. For example
``upsample_factor == 20`` means the images will be registered
within 1/20th of a pixel. Default is 1 (no upsampling)
space : string, one of ``real`` or ``fourier``
Defines how the algorithm interprets input data. ``real`` means data
will be FFT'd to compute the correlation, while ``fourier`` data will
bypass both FFT and any windowing specified by ``filter_function``.
Returns
-------
shifts : ndarray
Shift vector (in pixels) required to register ``target_image`` with
``src_image``. Axis ordering is consistent with numpy (e.g. Z, Y, X)
error : float
Translation invariant normalized RMS error between ``src_image`` and
``target_image``.
phasediff : float
Global phase difference between the two images (should be
zero if images are non-negative).
References
----------
.. [1] Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup,
"Efficient subpixel image registration algorithms,"
Optics Letters 33, 156-158 (2008).
"""
# images must be the same shape
if src_image.shape != target_image.shape:
raise ValueError("Error: images must be same size for "
"register_translation")

# only 2D data makes sense right now
if src_image.ndim != 2 and upsample_factor > 1:
raise NotImplementedError("Error: register_translation only supports "
"subpixel registration for 2D images")

# assume complex data is already in Fourier space
if space.lower() == 'fourier':
src_freq = src_image
target_freq = target_image
# real data needs to be fft'd.
elif space.lower() == 'real':
src_image = np.array(src_image, dtype=np.complex128, copy=False)
target_image = np.array(target_image, dtype=np.complex128, copy=False)
src_freq = np.fft.fftn(src_image)
target_freq = np.fft.fftn(target_image)
else:
raise ValueError("Error: register_translation only knows the \"real\" "
"and \"fourier\" values for the ``space`` argument.")

# Whole-pixel shift - Compute cross-correlation by an IFFT
shape = src_freq.shape
image_product = src_freq * target_freq.conj()
cross_correlation = np.fft.fftshift(np.fft.ifftn(image_product))

# Locate maximum
maxima = np.unravel_index(np.argmax(cross_correlation),
cross_correlation.shape)
midpoints = np.array([np.fix(axis_size / 2) for axis_size in shape])

shifts = np.array(maxima, dtype=np.float64)
shifts -= midpoints

if upsample_factor == 1:
src_amp = np.sum(np.abs(src_freq) ** 2) / src_freq.size
target_amp = np.sum(np.abs(target_freq) ** 2) / target_freq.size
CCmax = cross_correlation.max()
# If upsampling > 1, then refine estimate with matrix multiply DFT
else:
# Initial shift estimate in upsampled grid
shifts = np.round(shifts * upsample_factor) / upsample_factor
upsampled_region_size = np.ceil(upsample_factor * 1.5)
# Center of output array at dftshift + 1
dftshift = np.fix(upsampled_region_size / 2.0)
midpoint_product = np.product(midpoints)
normalization = (midpoint_product * upsample_factor ** 2)
# Matrix multiply DFT around the current shift estimate
sample_region_offset = shifts*upsample_factor + dftshift
cross_correlation = _upsampled_dft(image_product,
upsampled_region_size,
upsample_factor,
sample_region_offset).conj()
cross_correlation /= normalization
# Locate maximum and map back to original pixel grid
maxima = np.array(np.unravel_index(np.argmax(cross_correlation),
cross_correlation.shape),
dtype=np.float64)
maxima -= dftshift
shifts = shifts - maxima / upsample_factor
CCmax = cross_correlation.max()
src_amp = _upsampled_dft(src_freq * src_freq.conj(),
1, upsample_factor)[0, 0]
src_amp /= normalization
target_amp = _upsampled_dft(target_freq * target_freq.conj(),
1,
upsample_factor)[0, 0]
target_amp /= normalization

# If its only one row or column the shift along that dimension has no
# effect. We set to zero.
for dim in range(src_freq.ndim):
if midpoints[dim] == 1:
shifts[dim] = 0

return shifts, _compute_error(CCmax, src_amp, target_amp),\
_compute_phasediff(CCmax)


# TODO: this is here for the sake of testing the registration functions. It is
# more accurate than scipy.ndimage.shift, which uses spline interpolation
# to achieve the same purpose. However, in its current state, this
# function is far more limited than scipy.ndimage.shift. Improvements
# include choices on how to handle boundary wrap-around, and expansion to
# n-dimensions.
def fourier_shift(image, shift):
"""
Shift a real-space 2D image by shift by applying shift to phase in Fourier
space.
Parameters
----------
image : ndarray
Real-space 2D image to be shifted.
shift : length 2 array-like of floats
Shift to be applied to image. Order is row-major (y, x).
Returns
-------
out : ndarray
Shifted image. Boundaries wrap around.
"""
if image.ndim > 2:
raise NotImplementedError("Error: fourier_shift only supports "
" 2D images")
rows = np.fft.ifftshift(np.arange(-np.floor(image.shape[0] / 2),
np.ceil(image.shape[0] / 2)))
cols = np.fft.ifftshift(np.arange(-np.floor(image.shape[1] / 2),
np.ceil(image.shape[1] / 2)))
cols, rows = np.meshgrid(cols, rows)
out = np.fft.ifft2(np.fft.fft2(image) * np.exp(1j * 2 * np.pi *
(shift[0] * rows / image.shape[0] +
shift[1] * cols / image.shape[1])))
return out

0 comments on commit 9d97cca

Please sign in to comment.