# slic_superpixels.py
import math
from collections.abc import Iterable
from warnings import warn
import numpy as np
from numpy import random
from scipy.cluster.vq import kmeans2
from scipy.spatial.distance import pdist, squareform
from .._shared import utils
from .._shared.filters import gaussian
from ..color import rgb2lab
from ..util import img_as_float, regular_grid
from ._slic import _enforce_label_connectivity_cython, _slic_cython
def _get_mask_centroids(mask, n_centroids, multichannel):
    """Place approximately evenly spread seed points inside a mask.

    A reproducible random subset of the mask's nonzero voxels initializes a
    short k-means run (scipy's ``kmeans2``, 5 iterations) over the mask
    coordinates, which spreads the seeds over the masked region.

    Parameters
    ----------
    mask : 3D ndarray
        The mask within which the centroids must be positioned; nonzero
        entries are considered inside the mask.
    n_centroids : int
        The number of centroids requested. At most ``len(nonzero voxels)``
        seeds can be drawn, so fewer are returned for very small masks.
    multichannel : bool
        Only used to pick the exponent of the sampling-density heuristic
        that decides how many coordinates feed the k-means step.

    Returns
    -------
    centroids : 2D ndarray
        The coordinates of the centroids, one row per centroid and one
        column per mask axis (shape ``(n_centroids, 3)`` for a 3D mask).
    steps : 1D ndarray
        Per-axis mean absolute offset between each centroid and its nearest
        neighbouring centroid; an approximate seed spacing.
    """
    # Float coordinates of every voxel inside the mask, one row per voxel.
    points = np.array(np.nonzero(mask), dtype=float).T
    n_points = len(points)
    all_indices = np.arange(n_points, dtype=int)

    # Fixed seed so results are repeatable. The legacy RandomState API is
    # kept on purpose: expected results in the tests depend on its stream,
    # so the order and size of the draws below must not change.
    generator = random.RandomState(123)

    n_seeds = min(n_centroids, n_points)
    seed_indices = np.sort(generator.choice(all_indices, n_seeds, replace=False))

    # Running k-means on every mask voxel is wasteful when
    # n_centroids << n_points, so a random subset is used instead. The
    # subset is still far denser than the seeds: dense_factor=10 makes the
    # samples roughly 10 times closer together along each axis.
    dense_factor = 10
    # NOTE(review): when called from ``slic`` the mask has no channel axis,
    # so subtracting 1 for multichannel input looks questionable. It is kept
    # as-is because altering it changes the random draws (and thus results).
    spatial_ndim = mask.ndim - 1 if multichannel else mask.ndim
    n_dense = int((dense_factor**spatial_ndim) * n_centroids)

    if n_points > n_dense:
        dense_indices = np.sort(generator.choice(all_indices, n_dense, replace=False))
        sample = points[dense_indices]
    else:
        sample = points

    centroids, _ = kmeans2(sample, points[seed_indices], iter=5)

    # Distance from each centroid to its closest *other* centroid: the
    # diagonal (self-distance) is masked out with +inf before the argmin.
    pairwise = squareform(pdist(centroids))
    np.fill_diagonal(pairwise, np.inf)
    nearest = np.argmin(pairwise, axis=-1)
    steps = np.abs(centroids - centroids[nearest]).mean(axis=0)
    return centroids, steps
def _get_grid_centroids(image, n_centroids):
    """Find regularly spaced centroids on the image.

    Parameters
    ----------
    image : ndarray
        Input image with at least three dimensions. Only the first three
        axes (depth, rows, columns) are used to place the grid; ``slic``
        passes a 4D ``(depth, rows, cols, channels)`` array here.
    n_centroids : int
        The (approximate) number of centroids to be returned.

    Returns
    -------
    centroids : 2D ndarray
        Integer coordinates of the centroids with shape ``(~n_centroids, 3)``,
        ordered as (z, y, x) and enumerated in C (row-major) order.
    steps : 1D ndarray
        The grid step along each of the three axes (1.0 for an axis whose
        slice has no explicit step); the approximate distance between two
        seeds in all dimensions.
    """
    spatial_shape = image.shape[:3]
    slices = regular_grid(spatial_shape, n_centroids)

    # Only the sampled coordinates are needed, so slice a 1D range per axis
    # and combine them. This avoids materializing ``np.mgrid`` over the whole
    # volume (three integer arrays as large as the image) just to index it.
    axis_coords = [np.arange(size)[sl] for size, sl in zip(spatial_shape, slices)]
    grids = np.meshgrid(*axis_coords, indexing='ij')
    centroids = np.stack([grid.ravel() for grid in grids], axis=-1)

    steps = np.asarray([float(s.step) if s.step is not None else 1.0 for s in slices])
    return centroids, steps
@utils.channel_as_last_axis(multichannel_output=False)
def slic(
    image,
    n_segments=100,
    compactness=10.0,
    max_num_iter=10,
    sigma=0,
    spacing=None,
    convert2lab=None,
    enforce_connectivity=True,
    min_size_factor=0.5,
    max_size_factor=3,
    slic_zero=False,
    start_label=1,
    mask=None,
    *,
    channel_axis=-1,
):
    """Segments image using k-means clustering in Color-(x,y,z) space.

    Parameters
    ----------
    image : (M, N[, P][, C]) ndarray
        Input image. Can be 2D or 3D, and grayscale or multichannel
        (see `channel_axis` parameter).
        Input image must either be NaN-free or the NaN's must be masked out.
    n_segments : int, optional
        The (approximate) number of labels in the segmented output image.
    compactness : float, optional
        Balances color proximity and space proximity. Higher values give
        more weight to space proximity, making superpixel shapes more
        square/cubic. In SLICO mode, this is the initial compactness.
        This parameter depends strongly on image contrast and on the
        shapes of objects in the image. We recommend exploring possible
        values on a log scale, e.g., 0.01, 0.1, 1, 10, 100, before
        refining around a chosen value.
    max_num_iter : int, optional
        Maximum number of iterations of k-means.
    sigma : float or array-like of floats, optional
        Width of Gaussian smoothing kernel for pre-processing for each
        dimension of the image. The same sigma is applied to each dimension in
        case of a scalar value. Zero means no smoothing.
        Note that `sigma` is automatically scaled if it is scalar and
        if a manual voxel spacing is provided (see Notes section). If
        sigma is array-like, its size must match ``image``'s number
        of spatial dimensions.
    spacing : array-like of floats, optional
        The voxel spacing along each spatial dimension. By default,
        `slic` assumes uniform spacing (same voxel resolution along
        each spatial dimension).
        This parameter controls the weights of the distances along the
        spatial dimensions during k-means clustering.
    convert2lab : bool, optional
        Whether the input should be converted to Lab colorspace prior to
        segmentation. The input image *must* be RGB. Highly recommended.
        This option defaults to ``True`` when ``channel_axis`` is not None
        *and* ``image.shape[-1] == 3``.
    enforce_connectivity : bool, optional
        Whether to post-process the labels so that the generated segments
        are connected. The post-processing is controlled by
        `min_size_factor` and `max_size_factor`.
    min_size_factor : float, optional
        Proportion of the minimum segment size to be removed with respect
        to the supposed segment size ``depth*width*height/n_segments``
        (or ``mask.sum()/n_segments`` when a mask is given).
    max_size_factor : float, optional
        Proportion of the maximum connected segment size. A value of 3 works
        in most of the cases.
    slic_zero : bool, optional
        Run SLIC-zero, the zero-parameter mode of SLIC. [2]_
    start_label : int, optional
        The labels' index start. Should be 0 or 1.

        .. versionadded:: 0.17
           ``start_label`` was introduced in 0.17
    mask : ndarray, optional
        If provided, superpixels are computed only where mask is True,
        and seed points are homogeneously distributed over the mask
        using a k-means clustering strategy. Mask number of dimensions
        must be equal to image number of spatial dimensions.

        .. versionadded:: 0.17
           ``mask`` was introduced in 0.17
    channel_axis : int or None, optional
        If None, the image is assumed to be a grayscale (single channel) image.
        Otherwise, this parameter indicates which axis of the array corresponds
        to channels.

        .. versionadded:: 0.19
           ``channel_axis`` was added in 0.19.

    Returns
    -------
    labels : 2D or 3D array
        Integer mask indicating segment labels.

    Raises
    ------
    ValueError
        If ``convert2lab`` is set to ``True`` but the channel dimension of a
        multichannel image is not of length 3.
    ValueError
        If ``start_label`` is not 0 or 1.
    ValueError
        If ``image`` contains unmasked NaN values.
    ValueError
        If ``image`` contains unmasked infinite values.
    ValueError
        If ``image`` is 2D but ``channel_axis`` is not None (it defaults
        to -1).
    ValueError
        If ``spacing`` or ``sigma`` has an unsupported number of elements
        for the image dimensionality.
    TypeError
        If ``spacing`` is neither None nor iterable, or if ``sigma`` is
        neither a scalar nor iterable.

    Notes
    -----
    * If `sigma > 0`, the image is smoothed using a Gaussian kernel prior to
      segmentation.

    * If `sigma` is scalar and `spacing` is provided, the kernel width is
      divided along each dimension by the spacing. For example, if ``sigma=1``
      and ``spacing=[5, 1, 1]``, the effective `sigma` is ``[0.2, 1, 1]``. This
      ensures sensible smoothing for anisotropic images.

    * The image is rescaled to be in [0, 1] prior to processing (masked
      values are ignored).

    * Images of shape (M, N, 3) are interpreted as 2D RGB images by default. To
      interpret them as 3D with the last dimension having length 3, use
      `channel_axis=None`.

    * `start_label` is introduced to handle the issue [4]_. Label indexing
      starts at 1 by default.

    References
    ----------
    .. [1] Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi,
        Pascal Fua, and Sabine Süsstrunk, SLIC Superpixels Compared to
        State-of-the-art Superpixel Methods, TPAMI, May 2012.
        :DOI:`10.1109/TPAMI.2012.120`
    .. [2] https://www.epfl.ch/labs/ivrl/research/slic-superpixels/#SLICO
    .. [3] Irving, Benjamin. "maskSLIC: regional superpixel generation with
           application to local pathology characterisation in medical images.",
           2016, :arXiv:`1606.09518`
    .. [4] https://github.com/scikit-image/scikit-image/issues/3722

    Examples
    --------
    >>> from skimage.segmentation import slic
    >>> from skimage.data import astronaut
    >>> img = astronaut()
    >>> segments = slic(img, n_segments=100, compactness=10)

    Increasing the compactness parameter yields more square regions:

    >>> segments = slic(img, n_segments=100, compactness=20)
    """
    if image.ndim == 2 and channel_axis is not None:
        raise ValueError(
            f"channel_axis={channel_axis} indicates multichannel, which is not "
            "supported for a two-dimensional image; use channel_axis=None if "
            "the image is grayscale"
        )

    image = img_as_float(image)
    float_dtype = utils._supported_float_type(image.dtype)
    # copy=True so subsequent in-place operations do not modify the
    # function input
    image = image.astype(float_dtype, copy=True)

    if mask is not None:
        # Create masked_image to rescale while ignoring masked values
        mask = np.ascontiguousarray(mask, dtype=bool)
        if channel_axis is not None:
            # Repeat the spatial mask along the channel axis so it can
            # index the full multichannel image.
            mask_ = np.expand_dims(mask, axis=channel_axis)
            mask_ = np.broadcast_to(mask_, image.shape)
        else:
            mask_ = mask
        image_values = image[mask_]
    else:
        image_values = image

    # Rescale image to [0, 1] to make choice of compactness insensitive to
    # input image scale.
    imin = image_values.min()
    imax = image_values.max()
    # ``min`` propagates NaN, so a NaN minimum means an unmasked NaN exists.
    if np.isnan(imin):
        raise ValueError("unmasked NaN values in image are not supported")
    if np.isinf(imin) or np.isinf(imax):
        raise ValueError("unmasked infinite values in image are not supported")
    image -= imin
    if imax != imin:
        image /= imax - imin

    use_mask = mask is not None
    dtype = image.dtype

    # Normalize every input to a 4D (depth, rows, cols, channels) array.
    is_2d = False
    multichannel = channel_axis is not None
    if image.ndim == 2:
        # 2D grayscale image
        image = image[np.newaxis, ..., np.newaxis]
        is_2d = True
    elif image.ndim == 3 and multichannel:
        # Make 2D multichannel image 3D with depth = 1
        image = image[np.newaxis, ...]
        is_2d = True
    elif image.ndim == 3 and not multichannel:
        # Add channel as single last dimension
        image = image[..., np.newaxis]

    if multichannel and (convert2lab or convert2lab is None):
        if image.shape[channel_axis] != 3 and convert2lab:
            raise ValueError("Lab colorspace conversion requires a RGB image.")
        elif image.shape[channel_axis] == 3:
            image = rgb2lab(image)

    if start_label not in [0, 1]:
        raise ValueError("start_label should be 0 or 1.")

    # initialize cluster centroids for desired number of segments
    update_centroids = False
    if use_mask:
        mask = mask.view('uint8')
        if mask.ndim == 2:
            mask = np.ascontiguousarray(mask[np.newaxis, ...])
        if mask.shape != image.shape[:3]:
            raise ValueError("image and mask should have the same shape.")
        centroids, steps = _get_mask_centroids(mask, n_segments, multichannel)
        update_centroids = True
    else:
        centroids, steps = _get_grid_centroids(image, n_segments)

    # Normalize spacing to 3 elements (depth, rows, cols).
    if spacing is None:
        spacing = np.ones(3, dtype=dtype)
    elif isinstance(spacing, Iterable):
        spacing = np.asarray(spacing, dtype=dtype)
        if is_2d:
            if spacing.size != 2:
                if spacing.size == 3:
                    warn(
                        "Input image is 2D: spacing number of "
                        "elements must be 2. In the future, a ValueError "
                        "will be raised.",
                        FutureWarning,
                        stacklevel=2,
                    )
                else:
                    raise ValueError(
                        f"Input image is 2D, but spacing has "
                        f"{spacing.size} elements (expected 2)."
                    )
            else:
                # Unit spacing along the artificial depth axis.
                spacing = np.insert(spacing, 0, 1)
        elif spacing.size != 3:
            raise ValueError(
                f"Input image is 3D, but spacing has "
                f"{spacing.size} elements (expected 3)."
            )
        spacing = np.ascontiguousarray(spacing, dtype=dtype)
    else:
        raise TypeError("spacing must be None or iterable.")

    # Normalize sigma to 3 elements (depth, rows, cols).
    if np.isscalar(sigma):
        sigma = np.array([sigma, sigma, sigma], dtype=dtype)
        sigma /= spacing
    elif isinstance(sigma, Iterable):
        sigma = np.asarray(sigma, dtype=dtype)
        if is_2d:
            if sigma.size != 2:
                # Check sigma's own size (not spacing's, which is always 3
                # here) so only the legacy 3-element form is tolerated.
                if sigma.size == 3:
                    warn(
                        "Input image is 2D: sigma number of "
                        "elements must be 2. In the future, a ValueError "
                        "will be raised.",
                        FutureWarning,
                        stacklevel=2,
                    )
                else:
                    raise ValueError(
                        f"Input image is 2D, but sigma has "
                        f"{sigma.size} elements (expected 2)."
                    )
            else:
                # No smoothing along the artificial depth axis.
                sigma = np.insert(sigma, 0, 0)
        elif sigma.size != 3:
            raise ValueError(
                f"Input image is 3D, but sigma has "
                f"{sigma.size} elements (expected 3)."
            )
    else:
        raise TypeError("sigma must be a scalar or iterable.")

    if (sigma > 0).any():
        # add zero smoothing for channel dimension
        sigma = list(sigma) + [0]
        image = gaussian(image, sigma=sigma, mode='reflect')

    # Each segment row holds the spatial centroid (z, y, x) followed by one
    # zero-initialized entry per channel.
    n_centroids = centroids.shape[0]
    segments = np.ascontiguousarray(
        np.concatenate([centroids, np.zeros((n_centroids, image.shape[3]))], axis=-1),
        dtype=dtype,
    )

    # Scaling of ratio in the same way as in the SLIC paper so the
    # values have the same meaning
    step = max(steps)
    ratio = 1.0 / compactness

    image = np.ascontiguousarray(image * ratio, dtype=dtype)

    if update_centroids:
        # Step 2 of the algorithm [3]_
        _slic_cython(
            image,
            mask,
            segments,
            step,
            max_num_iter,
            spacing,
            slic_zero,
            ignore_color=True,
            start_label=start_label,
        )

    labels = _slic_cython(
        image,
        mask,
        segments,
        step,
        max_num_iter,
        spacing,
        slic_zero,
        ignore_color=False,
        start_label=start_label,
    )

    if enforce_connectivity:
        # Expected voxel count per segment, used to derive size thresholds.
        if use_mask:
            segment_size = mask.sum() / n_centroids
        else:
            segment_size = math.prod(image.shape[:3]) / n_centroids
        min_size = int(min_size_factor * segment_size)
        max_size = int(max_size_factor * segment_size)
        labels = _enforce_label_connectivity_cython(
            labels, min_size, max_size, start_label=start_label
        )

    if is_2d:
        # Drop the artificial depth axis added for 2D input.
        labels = labels[0]

    return labels