Skip to content

Commit

Permalink
improve multichannel handling and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
grlee77 committed May 19, 2015
1 parent 0956107 commit e6acdad
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 20 deletions.
13 changes: 8 additions & 5 deletions skimage/transform/_warps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def resize(image, output_shape, order=1, mode='constant', cval=0, clip=True,
preserve_range=False, multichannel=True):
preserve_range=False):
"""Resize image to match a certain size.
Performs interpolation to up-size or down-size images. For down-sampling
Expand Down Expand Up @@ -48,8 +48,6 @@ def resize(image, output_shape, order=1, mode='constant', cval=0, clip=True,
preserve_range : bool, optional
Whether to keep the original range of values. Otherwise, the input
image is converted according to the conventions of `img_as_float`.
multichannel : bool, optional
If True and ``image.ndim > 2``, treat last axis as channels.
Examples
--------
Expand All @@ -60,7 +58,6 @@ def resize(image, output_shape, order=1, mode='constant', cval=0, clip=True,
(100, 100)
"""

output_shape = np.asarray(output_shape)
orig_shape = image.shape
if len(orig_shape) < len(output_shape):
Expand Down Expand Up @@ -120,7 +117,7 @@ def resize(image, output_shape, order=1, mode='constant', cval=0, clip=True,


def rescale(image, scale, order=1, mode='constant', cval=0, clip=True,
preserve_range=False):
preserve_range=False, multichannel=None):
"""Scale image by a certain factor.
Performs interpolation to upscale or down-scale images. For down-sampling
Expand Down Expand Up @@ -159,6 +156,8 @@ def rescale(image, scale, order=1, mode='constant', cval=0, clip=True,
preserve_range : bool, optional
Whether to keep the original range of values. Otherwise, the input
image is converted according to the conventions of `img_as_float`.
multichannel : bool, optional
If True, last axis will not be rescaled.
Examples
--------
Expand All @@ -171,11 +170,15 @@ def rescale(image, scale, order=1, mode='constant', cval=0, clip=True,
(256, 256)
"""
if multichannel is None:
multichannel = False # maintain previous default behavior
scale = np.atleast_1d(scale)
if len(scale) > 1 and len(scale) != image.ndim:
raise ValueError("must supply a single scale or one value per axis.")
orig_shape = np.asarray(image.shape)
output_shape = np.round(scale * orig_shape).astype('i8')
if multichannel: # don't scale channel dimension
output_shape[-1] = orig_shape[-1]

return resize(image, output_shape, order=order, mode=mode, cval=cval,
clip=clip, preserve_range=preserve_range)
Expand Down
39 changes: 24 additions & 15 deletions skimage/transform/pyramids.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,20 @@
from ..util import img_as_float


def _smooth(image, sigma, mode, cval, multichannel=True):
"""Return image with each channel smoothed by the Gaussian filter."""
def _previous_multichannel_default(multichannel, ndim):
if multichannel is None:
# needed to maintain previous default behavior
if ndim == 3:
return True
else:
return False
else:
return multichannel


def _smooth(image, sigma, mode, cval, multichannel=None):
"""Return image with each channel smoothed by the Gaussian filter."""
multichannel = _previous_multichannel_default(multichannel, image.ndim)
smoothed = np.empty(image.shape, dtype=np.double)

# apply Gaussian filter to all channels independently
Expand All @@ -29,7 +40,7 @@ def _check_factor(factor):


def pyramid_reduce(image, downscale=2, sigma=None, order=1,
mode='reflect', cval=0, multichannel=True):
mode='reflect', cval=0, multichannel=None):
"""Smooth and then downsample image.
Parameters
Expand Down Expand Up @@ -63,13 +74,13 @@ def pyramid_reduce(image, downscale=2, sigma=None, order=1,
.. [1] http://web.mit.edu/persci/people/adelson/pub_pdfs/pyramid83.pdf
"""

multichannel = _previous_multichannel_default(multichannel, image.ndim)
_check_factor(downscale)

image = img_as_float(image)

out_shape = tuple([math.ceil(d / float(downscale)) for d in image.shape])
if multichannel and image.ndim > 2:
if multichannel:
out_shape = out_shape[:-1]

if sigma is None:
Expand All @@ -83,7 +94,7 @@ def pyramid_reduce(image, downscale=2, sigma=None, order=1,


def pyramid_expand(image, upscale=2, sigma=None, order=1,
mode='reflect', cval=0, multichannel=True):
mode='reflect', cval=0, multichannel=None):
"""Upsample and then smooth image.
Parameters
Expand Down Expand Up @@ -117,13 +128,13 @@ def pyramid_expand(image, upscale=2, sigma=None, order=1,
.. [1] http://web.mit.edu/persci/people/adelson/pub_pdfs/pyramid83.pdf
"""

multichannel = _previous_multichannel_default(multichannel, image.ndim)
_check_factor(upscale)

image = img_as_float(image)

out_shape = tuple([math.ceil(upscale * d) for d in image.shape])
if multichannel and image.ndim > 2:
if multichannel:
out_shape = out_shape[:-1]

if sigma is None:
Expand All @@ -138,7 +149,7 @@ def pyramid_expand(image, upscale=2, sigma=None, order=1,


def pyramid_gaussian(image, max_layer=-1, downscale=2, sigma=None, order=1,
mode='reflect', cval=0, multichannel=True):
mode='reflect', cval=0, multichannel=None):
"""Yield images of the Gaussian pyramid formed by the input image.
Recursively applies the `pyramid_reduce` function to the image, and yields
Expand Down Expand Up @@ -183,7 +194,6 @@ def pyramid_gaussian(image, max_layer=-1, downscale=2, sigma=None, order=1,
.. [1] http://web.mit.edu/persci/people/adelson/pub_pdfs/pyramid83.pdf
"""

_check_factor(downscale)

# cast to float for consistent data type in pyramid
Expand Down Expand Up @@ -215,7 +225,7 @@ def pyramid_gaussian(image, max_layer=-1, downscale=2, sigma=None, order=1,


def pyramid_laplacian(image, max_layer=-1, downscale=2, sigma=None, order=1,
mode='reflect', cval=0, multichannel=True):
mode='reflect', cval=0, multichannel=None):
"""Yield images of the laplacian pyramid formed by the input image.
Each layer contains the difference between the downsampled and the
Expand Down Expand Up @@ -264,7 +274,7 @@ def pyramid_laplacian(image, max_layer=-1, downscale=2, sigma=None, order=1,
.. [2] http://sepwww.stanford.edu/~morgan/texturematch/paper_html/node3.html
"""

multichannel = _previous_multichannel_default(multichannel, image.ndim)
_check_factor(downscale)

# cast to float for consistent data type in pyramid
Expand All @@ -288,12 +298,11 @@ def pyramid_laplacian(image, max_layer=-1, downscale=2, sigma=None, order=1,
out_shape = tuple(
[math.ceil(d / float(downscale)) for d in current_shape])

if multichannel and image.ndim > 2:
if multichannel:
out_shape = out_shape[:-1]

resized_image = resize(smoothed_image, out_shape,
order=order, mode=mode, cval=cval,
multichannel=multichannel)
order=order, mode=mode, cval=cval)
smoothed_image = _smooth(resized_image, sigma, mode, cval,
multichannel)

Expand Down
40 changes: 40 additions & 0 deletions skimage/transform/tests/test_warps.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,35 @@ def test_rescale():
assert_almost_equal(scaled, ref)


def test_rescale_multichannel():
# 1D + channels
x = np.zeros((8, 3), dtype=np.double)
scaled = rescale(x, 2, order=0, multichannel=True)
assert_equal(scaled.shape, (16, 3))
# 2D
scaled = rescale(x, 2, order=0, multichannel=False)
assert_equal(scaled.shape, (16, 6))
# multichannel defaults to False
scaled = rescale(x, 2, order=0)
assert_equal(scaled.shape, (16, 6))

# 2D + channels
x = np.zeros((8, 8, 3), dtype=np.double)
scaled = rescale(x, 2, order=0, multichannel=True)
assert_equal(scaled.shape, (16, 16, 3))
# 3D
scaled = rescale(x, 2, order=0, multichannel=False)
assert_equal(scaled.shape, (16, 16, 6))

# 3D + channels
x = np.zeros((8, 8, 8, 3), dtype=np.double)
scaled = rescale(x, 2, order=0, multichannel=True)
assert_equal(scaled.shape, (16, 16, 16, 3))
# 4D
scaled = rescale(x, 2, order=0, multichannel=False)
assert_equal(scaled.shape, (16, 16, 16, 6))


def test_resize2d():
x = np.zeros((5, 5), dtype=np.double)
x[1, 1] = 1
Expand Down Expand Up @@ -190,6 +219,17 @@ def test_resize2d_4d():
assert_almost_equal(resized, ref)


def test_resize_nd():
for dim in range(1, 6):
shape = 2 + np.arange(dim) * 2
x = np.ones(shape)
out_shape = np.asarray(shape) * 1.5
resized = resize(x, out_shape, order=0, mode='reflect')
expected_shape = 1.5 * shape
assert_equal(resized.shape, expected_shape)
assert np.all(resized == 1)


def test_resize3d_bilinear():
# bilinear 3rd dimension
x = np.zeros((5, 5, 2), dtype=np.double)
Expand Down

0 comments on commit e6acdad

Please sign in to comment.