From 9d97cca404a5b777dfaa747a02291f80222c871f Mon Sep 17 00:00:00 2001 From: Michael Sarahan Date: Mon, 28 Apr 2014 08:03:15 -0700 Subject: [PATCH] Add subpixel-precision image translation registration function to feature module --- TODO.txt | 2 + doc/examples/plot_register_translation.py | 83 ++++++ skimage/feature/__init__.py | 2 + skimage/feature/register_translation.py | 275 ++++++++++++++++++ .../tests/test_register_translation.py | 63 ++++ 5 files changed, 425 insertions(+) create mode 100644 doc/examples/plot_register_translation.py create mode 100644 skimage/feature/register_translation.py create mode 100644 skimage/feature/tests/test_register_translation.py diff --git a/TODO.txt b/TODO.txt index d6d95586987..98e8653b0dc 100644 --- a/TODO.txt +++ b/TODO.txt @@ -27,3 +27,5 @@ Version 0.12 `skimage.transform.PolynomialTransform._params`, `skimage.transform.PiecewiseAffineTransform.affines_*` attributes * Remove deprecated functions `skimage.filters.denoise_*` +* Add 3D phantom in `skimage.data` +* Add 3D test case of `skimage.feature.phase_correlate` diff --git a/doc/examples/plot_register_translation.py b/doc/examples/plot_register_translation.py new file mode 100644 index 00000000000..2a30f291897 --- /dev/null +++ b/doc/examples/plot_register_translation.py @@ -0,0 +1,83 @@ +""" +===================================== +Cross-Correlation (Phase Correlation) +===================================== + +In this example, we use phase correlation to identify the relative shift +between two similar-sized images. + +The ``register_translation`` function uses cross-correlation in Fourier space, +optionally employing an upsampled matrix-multiplication DFT to achieve +arbitrary subpixel precision. [1]_ + +.. [1] Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup, + "Efficient subpixel image registration algorithms," Optics Letters 33, + 156-158 (2008). + +""" +import numpy as np +import matplotlib.pyplot as plt + +from skimage import data +from skimage.feature import register_translation +from skimage.feature.register_translation import _upsampled_dft, fourier_shift + +image = data.camera() +shift = (-2.4, 1.32) +# (-2.4, 1.32) pixel offset relative to reference coin +offset_image = fourier_shift(image, shift) +print("Known offset (y, x):") +print(shift) + +# pixel precision first +shift, error, diffphase = register_translation(image, offset_image) + +fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(8, 3)) + +ax1.imshow(image) +ax1.set_axis_off() +ax1.set_title('Reference image') + +ax2.imshow(offset_image.real) +ax2.set_axis_off() +ax2.set_title('Offset image') + +# View the output of a cross-correlation to show what the algorithm is +# doing behind the scenes +image_product = np.fft.fft2(image) * np.fft.fft2(offset_image).conj() +cc_image = np.fft.fftshift(np.fft.ifft2(image_product)) +ax3.imshow(cc_image.real) +ax3.set_axis_off() +ax3.set_title("Cross-correlation") + +plt.show() + +print("Detected pixel offset (y, x):") +print(shift) + +# subpixel precision +shift, error, diffphase = register_translation(image, offset_image, 100) + +fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(8, 3)) + +ax1.imshow(image) +ax1.set_axis_off() +ax1.set_title('Reference image') + +ax2.imshow(offset_image.real) +ax2.set_axis_off() +ax2.set_title('Offset image') + +# Calculate the upsampled DFT, again to show what the algorithm is doing +# behind the scenes. Constants correspond to calculated values in routine. +# See source code for details. +cc_image = _upsampled_dft(image_product, 150, 100, (shift*100)+75).conj() +ax3.imshow(cc_image.real) +ax3.set_axis_off() +ax3.set_title("Supersampled XC sub-area") + + +plt.show() + +print("Detected subpixel offset (y, x):") +print(shift) diff --git a/skimage/feature/__init__.py b/skimage/feature/__init__.py index 90bb6e539f6..e51be11fde7 100644 --- a/skimage/feature/__init__.py +++ b/skimage/feature/__init__.py @@ -10,6 +10,7 @@ hessian_matrix_eigvals, hessian_matrix_det) from .corner_cy import corner_moravec, corner_orientations from .template import match_template +from .register_translation import register_translation from .brief import BRIEF from .censure import CENSURE from .orb import ORB @@ -40,6 +41,7 @@ 'corner_fast', 'corner_orientations', 'match_template', + 'register_translation', 'BRIEF', 'CENSURE', 'ORB', diff --git a/skimage/feature/register_translation.py b/skimage/feature/register_translation.py new file mode 100644 index 00000000000..947929398d5 --- /dev/null +++ b/skimage/feature/register_translation.py @@ -0,0 +1,275 @@ +# -*- coding: utf-8 -*- """ +""" +Port of Manuel Guizar's code from: +http://www.mathworks.com/matlabcentral/fileexchange/18401-efficient-subpixel-image-registration-by-cross-correlation +""" + +import numpy as np + + +def _upsampled_dft(data, upsampled_region_size=None, + upsample_factor=1, axis_offsets=None): + """ + Upsampled DFT by matrix multiplication. + + This code is intended to provide the same result as if the following + operations were performed: + - Embed the array "data" in an array that is ``upsample_factor`` times + larger in each dimension. ifftshift to bring the center of the + image to (1,1). + - Take the FFT of the larger array. + - Extract an ``[upsampled_region_size]`` region of the result, starting + with the ``[axis_offsets+1]`` element. + + It achieves this result by computing the DFT in the output array without + the need to zeropad. Much faster and memory efficient than the zero-padded + FFT approach if ``upsampled_region_size`` is much smaller than + ``data.size * upsample_factor``. + + Parameters + ---------- + data : 2D ndarray + The input data array (DFT of original data) to upsample. + upsampled_region_size : integer or tuple of integers, optional + The size of the region to be sampled. If one integer is provided, it + is duplicated up to the dimensionality of ``data``. If None, this is + equal to ``data.shape``. + upsample_factor : integer, optional + The upsampling factor. Defaults to 1. + axis_offsets : tuple of integers, optional + The offsets of the region to be sampled. Defaults to None (uses + image center) + + Returns + ------- + output : 2D ndarray + The upsampled DFT of the specified region. + """ + if upsampled_region_size is None: + upsampled_region_size = data.shape + # if people pass in an integer, expand it to a list of equal-sized sections + elif not hasattr(upsampled_region_size, "__iter__"): + upsampled_region_size = [upsampled_region_size, ] * data.ndim + else: + if len(upsampled_region_size) != data.ndim: + raise ValueError("shape of upsampled region sizes must be equal " + "to input data's number of dimensions.") + + if axis_offsets is None: + axis_offsets = [0, ] * data.ndim + elif not hasattr(axis_offsets, "__iter__"): + axis_offsets = [axis_offsets, ] * data.ndim + else: + if len(axis_offsets) != data.ndim: + raise ValueError("number of axis offsets must be equal to input " + "data's number of dimensions.") + + col_kernel = np.exp( + (-1j * 2 * np.pi / (data.shape[1] * upsample_factor)) * + (np.fft.ifftshift(np.arange(data.shape[1]))[:, None] - + np.floor(data.shape[1] / 2)).dot( + np.arange(upsampled_region_size[1])[None, :] - axis_offsets[1]) + ) + row_kernel = np.exp( + (-1j * 2 * np.pi / (data.shape[0] * upsample_factor)) * + (np.arange(upsampled_region_size[0])[:, None] - axis_offsets[0]).dot( + np.fft.ifftshift(np.arange(data.shape[0]))[None, :] - + np.floor(data.shape[0] / 2)) + ) + + return row_kernel.dot(data).dot(col_kernel) + + +def _compute_phasediff(cross_correlation_max): + """ + Compute global phase difference between the two images (should be + zero if images are non-negative). + + Parameters + ---------- + cross_correlation_max : complex + The complex value of the cross correlation at its maximum point. + """ + return np.arctan2(cross_correlation_max.imag, cross_correlation_max.real) + + +def _compute_error(cross_correlation_max, src_amp, target_amp): + """ + Compute RMS error metric between ``src_image`` and ``target_image``. + + Parameters + ---------- + cross_correlation_max : complex + The complex value of the cross correlation at its maximum point. + src_amp : float + The normalized average image intensity of the source image + target_amp : float + The normalized average image intensity of the target image + """ + error = 1.0 - cross_correlation_max * cross_correlation_max.conj() /\ + (src_amp * target_amp) + return np.sqrt(np.abs(error)) + + +def register_translation(src_image, target_image, upsample_factor=1, + space="real"): + """ + Efficient subpixel image translation registration by cross-correlation. + + This code gives the same precision as the FFT upsampled cross-correlation + in a fraction of the computation time and with reduced memory requirements. + It obtains an initial estimate of the cross-correlation peak by an FFT and + then refines the shift estimation by upsampling the DFT only in a small + neighborhood of that estimate by means of a matrix-multiply DFT. + + Parameters + ---------- + src_image : ndarray + Reference image. + target_image : ndarray + Image to register. Must be same dimensionality as ``src_image``. + upsample_factor : int, optional + Upsampling factor. Images will be registered to within + ``1 / upsample_factor`` of a pixel. For example + ``upsample_factor == 20`` means the images will be registered + within 1/20th of a pixel. Default is 1 (no upsampling) + space : string, one of ``real`` or ``fourier`` + Defines how the algorithm interprets input data. ``real`` means data + will be FFT'd to compute the correlation, while ``fourier`` data will + bypass both FFT and any windowing specified by ``filter_function``. + + Returns + ------- + shifts : ndarray + Shift vector (in pixels) required to register ``target_image`` with + ``src_image``. Axis ordering is consistent with numpy (e.g. Z, Y, X) + error : float + Translation invariant normalized RMS error between ``src_image`` and + ``target_image``. + phasediff : float + Global phase difference between the two images (should be + zero if images are non-negative). + + References + ---------- + .. [1] Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup, + "Efficient subpixel image registration algorithms," + Optics Letters 33, 156-158 (2008). + """ + # images must be the same shape + if src_image.shape != target_image.shape: + raise ValueError("Error: images must be same size for " + "register_translation") + + # only 2D data makes sense right now + if src_image.ndim != 2 and upsample_factor > 1: + raise NotImplementedError("Error: register_translation only supports " + "subpixel registration for 2D images") + + # assume complex data is already in Fourier space + if space.lower() == 'fourier': + src_freq = src_image + target_freq = target_image + # real data needs to be fft'd. + elif space.lower() == 'real': + src_image = np.array(src_image, dtype=np.complex128, copy=False) + target_image = np.array(target_image, dtype=np.complex128, copy=False) + src_freq = np.fft.fftn(src_image) + target_freq = np.fft.fftn(target_image) + else: + raise ValueError("Error: register_translation only knows the \"real\" " + "and \"fourier\" values for the ``space`` argument.") + + # Whole-pixel shift - Compute cross-correlation by an IFFT + shape = src_freq.shape + image_product = src_freq * target_freq.conj() + cross_correlation = np.fft.fftshift(np.fft.ifftn(image_product)) + + # Locate maximum + maxima = np.unravel_index(np.argmax(cross_correlation), + cross_correlation.shape) + midpoints = np.array([np.fix(axis_size / 2) for axis_size in shape]) + + shifts = np.array(maxima, dtype=np.float64) + shifts -= midpoints + + if upsample_factor == 1: + src_amp = np.sum(np.abs(src_freq) ** 2) / src_freq.size + target_amp = np.sum(np.abs(target_freq) ** 2) / target_freq.size + CCmax = cross_correlation.max() + # If upsampling > 1, then refine estimate with matrix multiply DFT + else: + # Initial shift estimate in upsampled grid + shifts = np.round(shifts * upsample_factor) / upsample_factor + upsampled_region_size = np.ceil(upsample_factor * 1.5) + # Center of output array at dftshift + 1 + dftshift = np.fix(upsampled_region_size / 2.0) + midpoint_product = np.product(midpoints) + normalization = (midpoint_product * upsample_factor ** 2) + # Matrix multiply DFT around the current shift estimate + sample_region_offset = shifts*upsample_factor + dftshift + cross_correlation = _upsampled_dft(image_product, + upsampled_region_size, + upsample_factor, + sample_region_offset).conj() + cross_correlation /= normalization + # Locate maximum and map back to original pixel grid + maxima = np.array(np.unravel_index(np.argmax(cross_correlation), + cross_correlation.shape), + dtype=np.float64) + maxima -= dftshift + shifts = shifts - maxima / upsample_factor + CCmax = cross_correlation.max() + src_amp = _upsampled_dft(src_freq * src_freq.conj(), + 1, upsample_factor)[0, 0] + src_amp /= normalization + target_amp = _upsampled_dft(target_freq * target_freq.conj(), + 1, + upsample_factor)[0, 0] + target_amp /= normalization + + # If its only one row or column the shift along that dimension has no + # effect. We set to zero. + for dim in range(src_freq.ndim): + if midpoints[dim] == 1: + shifts[dim] = 0 + + return shifts, _compute_error(CCmax, src_amp, target_amp),\ + _compute_phasediff(CCmax) + + +# TODO: this is here for the sake of testing the registration functions. It is +# more accurate than scipy.ndimage.shift, which uses spline interpolation +# to achieve the same purpose. However, in its current state, this +# function is far more limited than scipy.ndimage.shift. Improvements +# include choices on how to handle boundary wrap-around, and expansion to +# n-dimensions. +def fourier_shift(image, shift): + """ + Shift a real-space 2D image by shift by applying shift to phase in Fourier + space. + + Parameters + ---------- + image : ndarray + Real-space 2D image to be shifted. + shift : length 2 array-like of floats + Shift to be applied to image. Order is row-major (y, x). + + Returns + ------- + out : ndarray + Shifted image. Boundaries wrap around. + """ + if image.ndim > 2: + raise NotImplementedError("Error: fourier_shift only supports " + " 2D images") + rows = np.fft.ifftshift(np.arange(-np.floor(image.shape[0] / 2), + np.ceil(image.shape[0] / 2))) + cols = np.fft.ifftshift(np.arange(-np.floor(image.shape[1] / 2), + np.ceil(image.shape[1] / 2))) + cols, rows = np.meshgrid(cols, rows) + out = np.fft.ifft2(np.fft.fft2(image) * np.exp(1j * 2 * np.pi * + (shift[0] * rows / image.shape[0] + + shift[1] * cols / image.shape[1]))) + return out diff --git a/skimage/feature/tests/test_register_translation.py b/skimage/feature/tests/test_register_translation.py new file mode 100644 index 00000000000..119aa37abb9 --- /dev/null +++ b/skimage/feature/tests/test_register_translation.py @@ -0,0 +1,63 @@ +import numpy as np +from numpy.testing import assert_allclose, assert_raises + +from skimage.feature.register_translation import register_translation,\ + fourier_shift +from skimage.data import camera + + +def test_correlation(): + image = camera() + shift = (-7, 12) + shifted_image = fourier_shift(image, shift) + + # pixel precision + result, error, diffphase = register_translation(image, shifted_image) + + assert_allclose(result[:2], np.array(shift)) + + +def test_subpixel_precision(): + reference_image = camera() + subpixel_shift = (-2.4, 1.32) + shifted_image = fourier_shift(reference_image, subpixel_shift) + + # subpixel precision + result, error, diffphase = register_translation(reference_image, + shifted_image, 100) + + assert_allclose(result[:2], np.array(subpixel_shift), atol=0.05) + + +def test_3d_input(): + # TODO: this test case is waiting on a Phantom data set to be added to the + # data module. + # pixel precision + # result, error, diffphase = register_translation(ref_image, shifted_image) + + # assert_allclose(np.array(result[:2]), np.array(shift)) + pass + + +def test_wrong_input(): + # Dimensionality mismatch + image = np.ones((5, 5, 1)) + template = np.ones((5, 5)) + assert_raises(ValueError, register_translation, template, image) + + # Greater than 2 dimensions does not support subpixel precision + # (TODO: should support 3D at some point.) + image = np.ones((5, 5, 5)) + template = np.ones((5, 5, 5)) + assert_raises(NotImplementedError, register_translation, + template, image, 2) + + # Size mismatch + image = np.ones((5, 5)) + template = np.ones((4, 4)) + assert_raises(ValueError, register_translation, template, image) + + +if __name__ == "__main__": + from numpy import testing + testing.run_module_suite()