forked from scikit-image/scikit-image
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_histogram_matching.py
83 lines (63 loc) · 2.9 KB
/
test_histogram_matching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy as np
from skimage.exposure import histogram_matching
from skimage import exposure
from skimage import data
from skimage._shared.testing import assert_array_almost_equal, \
assert_almost_equal
import pytest
@pytest.mark.parametrize('array, template, expected_array', [
    (np.arange(10), np.arange(100), np.arange(9, 100, 10)),
    (np.random.rand(4), np.ones(3), np.ones(4))
])
def test_match_array_values(array, template, expected_array):
    """Compare ``_match_cumulative_cdf`` output with the expected array.

    Each parametrized case supplies a source ``array``, a ``template`` and
    the ``expected_array`` the matching routine should produce (up to the
    tolerance of ``assert_array_almost_equal``). In the second case the
    source is random but the template is all ones, so the expected result
    is all ones.
    """
    # Run the private matching routine on the source/template pair ...
    result = histogram_matching._match_cumulative_cdf(array, template)

    # ... and compare element-wise against the precomputed expectation.
    assert_array_almost_equal(result, expected_array)
class TestMatchHistogram:
    """Tests for ``exposure.match_histograms``.

    The sample images from ``data.chelsea()`` and ``data.astronaut()`` are
    stored as class attributes so that the ``parametrize`` decorators below
    can reference them (and single-channel slices of them) directly.
    """

    image_rgb = data.chelsea()
    template_rgb = data.astronaut()

    @pytest.mark.parametrize('image, reference, multichannel', [
        (image_rgb, template_rgb, True),
        (image_rgb[:, :, 0], template_rgb[:, :, 0], False)
    ])
    def test_match_histograms(self, image, reference, multichannel):
        """Check that the matched image follows the reference distribution.

        For every distinct value in every channel of the matched image, its
        cumulative quantile must agree (to one decimal place) with the
        cumulative quantile of the closest value found in the same channel
        of the reference image.
        """
        result = exposure.match_histograms(image, reference,
                                           multichannel=multichannel)

        result_stats = self._calculate_image_empirical_pdf(result)
        reference_stats = self._calculate_image_empirical_pdf(reference)

        for index, (result_values, result_quantiles) in enumerate(
                result_stats):
            ref_values, ref_quantiles = reference_stats[index]

            for value, quantile in zip(result_values, result_quantiles):
                # Locate the reference value nearest to this matched value
                # and compare the two cumulative quantiles.
                distances = np.abs(ref_values - value)
                nearest = distances.argmin()
                assert_almost_equal(quantile,
                                    ref_quantiles[nearest],
                                    decimal=1)

    @pytest.mark.parametrize('image, reference', [
        (image_rgb, template_rgb[:, :, 0]),
        (image_rgb[:, :, 0], template_rgb)
    ])
    def test_raises_value_error_on_channels_mismatch(self, image, reference):
        """A 3-D image paired with a 2-D one (either way) must be rejected."""
        with pytest.raises(ValueError):
            exposure.match_histograms(image, reference)

    @classmethod
    def _calculate_image_empirical_pdf(cls, image):
        """Return per-channel unique values and their cumulative quantiles.

        Despite the name, the second element of each pair is the normalized
        cumulative count of the unique values (running sum of counts divided
        by the total), not a density.

        Parameters
        ----------
        image : ndarray
            Array with channels on axis 2 when ``image.ndim > 2``; arrays of
            lower rank are promoted to a single leading channel.

        Returns
        -------
        ndarray of object
            One ``(values, quantiles)`` entry per channel.
        """
        def _channel_stats(plane):
            # Unique values with counts, then counts accumulated and scaled
            # so that the last quantile is exactly 1.0.
            values, counts = np.unique(plane, return_counts=True)
            running = np.cumsum(counts).astype(np.float64)
            return values, running / running[-1]

        # Bring the channel axis to the front so iteration yields one 2-D
        # plane per channel.
        if image.ndim > 2:
            image = image.transpose(2, 0, 1)
        planes = np.array(image, copy=False, ndmin=3)

        return np.asarray([_channel_stats(plane) for plane in planes],
                          dtype=object)