Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

BUG: ndimage: Merge PR #240 from 'thouis/ticket-1491'

Fix ndimage.label ignoring its output keyword.
  • Loading branch information...
commit 5f0b79605e1acb56cb65adffbf7799538266d7b4 2 parents 8bb894b + 960e3c4
@pv pv authored
View
19 scipy/ndimage/measurements.py
@@ -148,16 +148,33 @@ def label(input, structure=None, output=None):
raise RuntimeError('structure dimensions must be equal to 3')
if not structure.flags.contiguous:
structure = structure.copy()
+ requested_output = None
+ requested_dtype = None
if isinstance(output, numpy.ndarray):
if output.dtype.type != numpy.int32:
- raise RuntimeError('output type must be int32')
+ if output.shape != input.shape:
+ raise RuntimeError("output shape not correct")
+ # _ndimage.label() needs np.int32
+ requested_output = output
+ output = numpy.int32
+ else:
+ # output will be written directly
+ pass
else:
+ requested_dtype = output
output = numpy.int32
output, return_value = _ni_support._get_output(output, input)
max_label = _nd_image.label(input, structure, output)
if return_value is None:
+ # result was written in-place
+ return max_label
+ elif requested_output is not None:
+ # original output was not int32
+ requested_output[...] = output[...]
return max_label
else:
+ if requested_dtype is not None:
+ return_value = return_value.astype(requested_dtype)
return return_value, max_label
def find_objects(input, max_label=0):
View
26 scipy/ndimage/tests/test_measurements.py
@@ -1,6 +1,6 @@
from numpy.testing import assert_, assert_array_almost_equal, assert_equal, \
assert_almost_equal, assert_array_equal, \
- run_module_suite, TestCase
+ assert_raises, run_module_suite, TestCase
import numpy as np
import scipy.ndimage as ndimage
@@ -261,6 +261,30 @@ def test_label13():
assert_array_almost_equal(out, expected)
assert_equal(n, 1)
+def test_label_output_typed():
+ "test label with specified output with type"
+ data = np.ones([5])
+ for t in types:
+ output = np.zeros([5], dtype=t)
+ n = ndimage.label(data, output=output)
+ assert_array_almost_equal(output, 1)
+ assert_equal(n, 1)
+
+def test_label_output_dtype():
+ "test label with specified output dtype"
+ data = np.ones([5])
+ for t in types:
+ output, n = ndimage.label(data, output=t)
+ assert_array_almost_equal(output, 1)
+ assert output.dtype == t
+
+def test_label_output_wrong_size():
+ "test label with output of wrong size"
+ data = np.ones([5])
+ for t in types:
+ output = np.zeros([10], t)
+ assert_raises(RuntimeError, ndimage.label, data, output=output)
+
def test_find_objects01():
"find_objects 1"
data = np.ones([], dtype=int)
Please sign in to comment.
Something went wrong with that request. Please try again.