Skip to content

Commit

Permalink
BUG: ndimage: Merge PR #240 from 'thouis/ticket-1491'
Browse files Browse the repository at this point in the history
Fix ndimage.label ignoring its output keyword.
  • Loading branch information
pv committed Jun 9, 2012
2 parents 8bb894b + 960e3c4 commit 5f0b796
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
19 changes: 18 additions & 1 deletion scipy/ndimage/measurements.py
Expand Up @@ -148,16 +148,33 @@ def label(input, structure=None, output=None):
raise RuntimeError('structure dimensions must be equal to 3') raise RuntimeError('structure dimensions must be equal to 3')
if not structure.flags.contiguous: if not structure.flags.contiguous:
structure = structure.copy() structure = structure.copy()
requested_output = None
requested_dtype = None
if isinstance(output, numpy.ndarray): if isinstance(output, numpy.ndarray):
if output.dtype.type != numpy.int32: if output.dtype.type != numpy.int32:
raise RuntimeError('output type must be int32') if output.shape != input.shape:
raise RuntimeError("output shape not correct")
# _ndimage.label() needs np.int32
requested_output = output
output = numpy.int32
else:
# output will be written directly
pass
else: else:
requested_dtype = output
output = numpy.int32 output = numpy.int32
output, return_value = _ni_support._get_output(output, input) output, return_value = _ni_support._get_output(output, input)
max_label = _nd_image.label(input, structure, output) max_label = _nd_image.label(input, structure, output)
if return_value is None: if return_value is None:
# result was written in-place
return max_label
elif requested_output is not None:
# original output was not int32
requested_output[...] = output[...]
return max_label return max_label
else: else:
if requested_dtype is not None:
return_value = return_value.astype(requested_dtype)
return return_value, max_label return return_value, max_label


def find_objects(input, max_label=0): def find_objects(input, max_label=0):
Expand Down
26 changes: 25 additions & 1 deletion scipy/ndimage/tests/test_measurements.py
@@ -1,6 +1,6 @@
from numpy.testing import assert_, assert_array_almost_equal, assert_equal, \ from numpy.testing import assert_, assert_array_almost_equal, assert_equal, \
assert_almost_equal, assert_array_equal, \ assert_almost_equal, assert_array_equal, \
run_module_suite, TestCase assert_raises, run_module_suite, TestCase
import numpy as np import numpy as np


import scipy.ndimage as ndimage import scipy.ndimage as ndimage
Expand Down Expand Up @@ -261,6 +261,30 @@ def test_label13():
assert_array_almost_equal(out, expected) assert_array_almost_equal(out, expected)
assert_equal(n, 1) assert_equal(n, 1)


def test_label_output_typed():
"test label with specified output with type"
data = np.ones([5])
for t in types:
output = np.zeros([5], dtype=t)
n = ndimage.label(data, output=output)
assert_array_almost_equal(output, 1)
assert_equal(n, 1)

def test_label_output_dtype():
"test label with specified output dtype"
data = np.ones([5])
for t in types:
output, n = ndimage.label(data, output=t)
assert_array_almost_equal(output, 1)
assert output.dtype == t

def test_label_output_wrong_size():
"test label with output of wrong size"
data = np.ones([5])
for t in types:
output = np.zeros([10], t)
assert_raises(RuntimeError, ndimage.label, data, output=output)

def test_find_objects01(): def test_find_objects01():
"find_objects 1" "find_objects 1"
data = np.ones([], dtype=int) data = np.ones([], dtype=int)
Expand Down

0 comments on commit 5f0b796

Please sign in to comment.