Permalink
Browse files

Add thresholding module with Otsu's method to calculate threshold.

  • Loading branch information...
1 parent a635c23 commit fbbe38765d4afa7de1126540c31150a3ba94f862 @tonysyu tonysyu committed Dec 9, 2011
@@ -0,0 +1 @@
+from .thresholding import otsu_threshold, binarize
@@ -0,0 +1,53 @@
+import numpy as np
+
+import skimage
+from skimage import data
+from skimage.thresholding import otsu_threshold, binarize
+
+
+class TestSimpleImage():
+ def setup(self):
+ self.image = np.array([[0, 0, 1, 3, 5],
+ [0, 1, 4, 3, 4],
+ [1, 2, 5, 4, 1],
+ [2, 4, 5, 2, 1],
+ [4, 5, 1, 0, 0]], dtype=int)
+
+ def test_otsu(self):
+ assert otsu_threshold(self.image) == 2
+
+ @np.testing.raises(NotImplementedError)
+ def test_otsu_raises_error(self):
+ image = self.image - 2
+ otsu_threshold(image)
+
+ def test_otsu_float_image(self):
+ image = np.float64(self.image)
+ assert 2 <= otsu_threshold(image) < 3
+
+ def test_binarize(self):
+ expected = np.array([[0, 0, 0, 1, 1],
+ [0, 0, 1, 1, 1],
+ [0, 0, 1, 1, 0],
+ [0, 1, 1, 0, 0],
+ [1, 1, 0, 0, 0]])
+ assert np.all(binarize(self.image) == expected)
+
+
+def test_otsu_camera_image():
+ assert otsu_threshold(data.camera()) == 87
+
+def test_otsu_coins_image():
+ assert otsu_threshold(data.coins()) == 107
+
+def test_otsu_coins_image_as_float():
+ coins = skimage.img_as_float(data.coins())
+ assert 0.41 < otsu_threshold(coins) < 0.42
+
+def test_otsu_lena_image():
+ assert otsu_threshold(data.lena()) == 141
+
+
+if __name__ == '__main__':
+ np.testing.run_module_suite()
+
@@ -0,0 +1,105 @@
+import numpy as np
+
+
+__all__ = ['otsu_threshold', 'binarize']
+
+
+def otsu_threshold(image, bins=256):
+ """Return threshold value based on Otsu's method.
+
+ Parameters
+ ----------
+ image : array
+ Input image.
+ bins : int
+ Number of bins used to calculate histogram. This value is ignored for
+ integer arrays.
+
+ Returns
+ -------
+ threshold : numeric
+ Threshold value. int or float depending on input image.
+
+ References
+ ----------
+ .. [1] Wikipedia, http://en.wikipedia.org/wiki/Otsu's_Method
+
+ """
+ hist, bin_centers = histogram(image, bins)
+ hist = hist.astype(float)
+
+ # class probabilities for all possible thresholds
+ weight1 = np.cumsum(hist)
+ weight2 = np.cumsum(hist[::-1])[::-1]
+ # class means for all possible thresholds
+ mean1 = np.cumsum(hist * bin_centers) / weight1
+ mean2 = (np.cumsum((hist * bin_centers)[::-1]) / weight2[::-1])[::-1]
+
+ # Clip ends to align class 1 and class 2 variables:
+ # The last value of `weight1`/`mean1` should pair with zero values in
+ # `weight2`/`mean2`, which do not exist.
+ variance12 = weight1[:-1] * weight2[1:] * (mean1[:-1] - mean2[1:])**2
+
+ idx = np.argmax(variance12)
+ threshold = bin_centers[:-1][idx]
+ return threshold
+
+
+_threshold_funcs = {'otsu': otsu_threshold}
+def binarize(image, method='otsu'):
+ """Return binary image using an automatic thresholding method.
+
+ Parameters
+ ----------
+ image : array
+ Input array.
+ method : {'otsu'}
+ Method used to calculate threshold value. Currently, only Otsu's method
+ is implemented.
+
+ Returns
+ -------
+ out : array
+ Thresholded image.
+ """
+ get_threshold = _threshold_funcs[method]
+ threshold = get_threshold(image)
+ return image > threshold
+
+
+def histogram(image, bins):
+ """Return histogram of image.
+
+ Unlike `numpy.histogram`, this function returns the centers of bins and
+ does not rebin integer arrays.
+
+ Parameters
+ ----------
+ image : array
+ Input image.
+ bins : int
+ Number of bins used to calculate histogram. This value is ignored for
+ integer arrays.
+
+ Returns
+ -------
+ hist : array
+ The values of the histogram.
+ bin_centers : array
+ The values at the center of the bins.
+ """
+ if np.issubdtype(image.dtype, np.integer):
+ if np.min(image) < 0:
+ msg = "Images with negative values not allowed"
+ raise NotImplementedError(msg)
+ hist = np.bincount(image.flat)
+ bin_centers = np.arange(len(hist))
+
+ # clip histogram to return only non-zero bins
+ idx = np.nonzero(hist)[0][0]
+ return hist[idx:], bin_centers[idx:]
+ else:
+ hist, bin_edges = np.histogram(image, bins=bins)
+ bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.
+ return hist, bin_centers
+

0 comments on commit fbbe387

Please sign in to comment.