Skip to content

Commit

Permalink
support use of channel_axis in functions wrapped by adapt_rgb
Browse files Browse the repository at this point in the history
support channel_axis for hsv_value and each_channel decorators
  • Loading branch information
grlee77 committed Jul 2, 2021
1 parent 3d82e2c commit 8063049
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 46 deletions.
28 changes: 10 additions & 18 deletions skimage/color/adapt_rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ def is_rgb_like(image, channel_axis=-1):
for functions that don't accept volumes as input, since checking an image's
shape is fragile.
"""
if channel_axis is None:
return False
return (image.ndim == 3) and (image.shape[channel_axis] in (3, 4))


def adapt_rgb(apply_to_rgb, channel_axis=-1):
def adapt_rgb(apply_to_rgb):
"""Return decorator that adapts to RGB images to a gray-scale filter.
This function is only intended to be used for functions that don't accept
Expand All @@ -31,28 +33,14 @@ def adapt_rgb(apply_to_rgb, channel_axis=-1):
----------
apply_to_rgb : function
Function that returns a filtered image from an image-filter and RGB
image. This will only be called if the image is RGB-like. This function
must have an argument named `channel_axis` that specified which axis of
the image corresponds to channels.
image. This will only be called if the image is RGB-like.
"""
sig = inspect.signature(apply_to_rgb)
if 'channel_axis' not in sig.parameters:
if channel_axis == -1:
channel_kwarg = {}
else:
# only raise on channel_axis != -1 for backwards compatibility
raise ValueError(
"apply_to_rgb must take an argument named `channel_axis`"
)
else:
channel_kwarg = dict(channel_axis=channel_axis)

def decorator(image_filter):
@functools.wraps(image_filter)
def image_filter_adapted(image, *args, **kwargs):
channel_axis = kwargs.get('channel_axis', -1)
if is_rgb_like(image, channel_axis=channel_axis):
return apply_to_rgb(image_filter, image, *args,
**channel_kwarg, **kwargs)
return apply_to_rgb(image_filter, image, *args, **kwargs)
else:
return image_filter(image, *args, **kwargs)
return image_filter_adapted
Expand All @@ -70,6 +58,8 @@ def hsv_value(image_filter, image, *args, channel_axis=-1, **kwargs):
Function that filters a gray-scale image.
image : array
Input image. Note that RGBA images are treated as RGB.
channel_axis : int or None, optional
This parameter specifies which axis corresponds to `channels`.
"""
# Slice the first three channels so that we remove any alpha channels.
channel_axis = channel_axis % image.ndim
Expand All @@ -93,6 +83,8 @@ def each_channel(image_filter, image, *args, channel_axis=-1, **kwargs):
Function that filters a gray-scale image.
image : array
Input image.
channel_axis : int or None, optional
This parameter specifies which axis corresponds to `channels`.
"""
c_new = [image_filter(c, *args, **kwargs)
for c in np.moveaxis(image, source=channel_axis, destination=0)]
Expand Down
56 changes: 31 additions & 25 deletions skimage/color/tests/test_adapt_rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ def smooth_each(image, sigma):
return filters.gaussian(image, sigma)


@adapt_rgb(each_channel, channel_axis=0)
def smooth_each_axis0(image, sigma):
# if the function has a channel_axis argument, it will be picked up by
# adapt_rgb as the axis to iterate over
@adapt_rgb(each_channel)
def smooth_each_axis(image, sigma, channel_axis=-1):
return filters.gaussian(image, sigma)


Expand All @@ -49,8 +51,8 @@ def smooth_hsv(image, sigma):
return filters.gaussian(image, sigma)


@adapt_rgb(hsv_value, channel_axis=0)
def smooth_hsv_axis0(image, sigma):
@adapt_rgb(hsv_value)
def smooth_hsv_axis(image, sigma, channel_axis=-1):
return filters.gaussian(image, sigma)


Expand Down Expand Up @@ -80,9 +82,11 @@ def test_each_channel_with_filter_argument():
assert_allclose(channel, smooth(COLOR_IMAGE[:, :, i]))


def test_each_channel_with_filter_argument_axis0():
color_img = np.moveaxis(COLOR_IMAGE, source=-1, destination=0)
filtered = smooth_each_axis0(color_img, SIGMA)
@pytest.mark.parametrize("channel_axis", [0, 1, 2, -1])
def test_each_channel_with_filter_and_axis_argument(channel_axis):
color_img = np.moveaxis(COLOR_IMAGE, source=-1, destination=channel_axis)
filtered = smooth_each_axis(color_img, SIGMA, channel_axis=channel_axis)
filtered = np.moveaxis(filtered, source=channel_axis, destination=0)
for i, channel in enumerate(filtered):
assert_allclose(channel, smooth(COLOR_IMAGE[:, :, i]))

Expand All @@ -104,11 +108,13 @@ def test_hsv_value_with_filter_argument():
assert_allclose(color.rgb2hsv(filtered)[:, :, 2], smooth(value))


def test_hsv_value_with_filter_argument_axis0():
color_img = np.moveaxis(COLOR_IMAGE, source=-1, destination=0)
filtered = smooth_hsv_axis0(color_img, SIGMA)
@pytest.mark.parametrize("channel_axis", [0, 1, 2, -1])
def test_hsv_value_with_filter_and_axis_argument(channel_axis):
color_img = np.moveaxis(COLOR_IMAGE, source=-1, destination=channel_axis)
filtered = smooth_hsv_axis(color_img, SIGMA, channel_axis=channel_axis)
filtered = np.moveaxis(filtered, source=channel_axis, destination=-1)
value = color.rgb2hsv(COLOR_IMAGE)[:, :, 2]
assert_allclose(color.rgb2hsv(filtered, channel_axis=0)[2], smooth(value))
assert_allclose(color.rgb2hsv(filtered)[..., 2], smooth(value))


def test_hsv_value_with_non_float_output():
Expand All @@ -122,20 +128,20 @@ def test_hsv_value_with_non_float_output():
assert_allclose(filtered_value, filters.sobel(value), rtol=1e-5, atol=1e-5)


def test_missing_channel_axis_param():
# def test_missing_channel_axis_param():

def _identity(image_filter, image):
return image
# def _identity(image_filter, image):
# return image

# when channel_axis != -1, the function passed to adapt_rgb must have a
# channel_axis argument
with pytest.raises(ValueError):
@adapt_rgb(_identity, channel_axis=0)
def identity(image):
return image
# # when channel_axis != -1, the function passed to adapt_rgb must have a
# # channel_axis argument
# with pytest.raises(ValueError):
# @adapt_rgb(_identity, channel_axis=0)
# def identity(image):
# return image

# default channel_axis=-1 doesn't raise an error
@adapt_rgb(_identity)
def identity(image):
return image
assert_array_equal(COLOR_IMAGE, identity(COLOR_IMAGE))
# # default channel_axis=-1 doesn't raise an error
# @adapt_rgb(_identity)
# def identity(image):
# return image
# assert_array_equal(COLOR_IMAGE, identity(COLOR_IMAGE))
8 changes: 7 additions & 1 deletion skimage/exposure/_adapthist.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""
import numbers
import numpy as np

from .._shared.utils import channel_as_last_axis
from ..util import img_as_float, img_as_uint
from ..color.adapt_rgb import adapt_rgb, hsv_value
from ..exposure import rescale_intensity
Expand All @@ -25,7 +27,7 @@

@adapt_rgb(hsv_value)
def equalize_adapthist(image, kernel_size=None,
clip_limit=0.01, nbins=256):
clip_limit=0.01, nbins=256, channel_axis=-1):
"""Contrast Limited Adaptive Histogram Equalization (CLAHE).
An algorithm for local contrast enhancement, that uses histograms computed
Expand All @@ -47,6 +49,10 @@ def equalize_adapthist(image, kernel_size=None,
contrast).
nbins : int, optional
Number of gray bins for histogram ("data range").
channel_axis : int, optional
When the image is a 2D RGB or RGBA image, this parameter specifies
which axis corresponds to the color `channels`. Otherwise, this
parameter is ignored.
Returns
-------
Expand Down
9 changes: 7 additions & 2 deletions skimage/exposure/tests/test_exposure.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,15 +372,20 @@ def test_adapthist_grayscale():
assert_almost_equal(norm_brightness_err(img, adapted), 0.0529, 3)


def test_adapthist_color():
@pytest.mark.parametrize("channel_axis", [0, 1, -1])
def test_adapthist_color(channel_axis):
"""Test an RGB color uint16 image
"""
img = util.img_as_uint(data.astronaut())
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
hist, bin_centers = exposure.histogram(img)
assert len(w) > 0
adapted = exposure.equalize_adapthist(img, clip_limit=0.01)

_img = np.moveaxis(img, source=-1, destination=channel_axis)
adapted = exposure.equalize_adapthist(_img, clip_limit=0.01,
channel_axis=channel_axis)
adapted = np.moveaxis(adapted, source=channel_axis, destination=-1)

assert adapted.min() == 0
assert adapted.max() == 1.0
Expand Down

0 comments on commit 8063049

Please sign in to comment.