Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP

Loading…

BUG: ndimage: measurements._stats() did not handle n-d arrays (ticket 15... #113

Closed
wants to merge 1 commit into from

2 participants

Warren Weckesser Ralf Gommers
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
This page is out of date. Refresh to see the latest.
46 scipy/ndimage/measurements.py
View
@@ -380,7 +380,35 @@ def _safely_castable_to_int(dt):
return safe
def _stats(input, labels=None, index=None, centered=False):
- '''returns count, sum, and optionally (sum - centre)^2 by label'''
+ '''Count, sum, and optionally compute (sum - centre)^2 of input by label
+
+ Parameters
+ ----------
+ input : array_like, n-dimensional
+ The input data to be analyzed.
+ labels : array_like (n-dimensional) or None
+ The labels of the data in `input`. This array must be broadcast
+ compatible with `input`; typically it is the same shape as `input`.
+ If `labels` is None, all nonzero values in `input` are treated as
+ the single labeled group.
+ index: label, sequence of labels, or None
+ These are the labels of the groups for which the stats are computed.
+ If `index` is None, the stats are computed for the single group where
+ `labels` is greater than 0.
+ centered: bool
+ If True, the centered sum of squares for each labeled group is
+ also returned.
+
+ Return value
+ ------------
+ counts:
+ The number of elements in each labeled group.
+ sums:
+ The sums of the values in each labeled group.
+ sums_c:
+ The sums of mean-centered squares of the values in each labeled group.
+ This is only returned if `centered` is True.
+ '''
def single_group(vals):
if centered:
@@ -402,9 +430,13 @@ def single_group(vals):
return single_group(input[labels == index])
def _sum_centered(labels):
+ # `labels` is expected to be an ndarray with the same shape as `input`.
+ # It must contain the label indices (which are not necessarily the labels
+ # themselves).
means = sums / counts
centered_input = input - means[labels]
- bc = numpy.bincount(labels,
+ # bincount expects 1d inputs, so we ravel the arguments.
+ bc = numpy.bincount(labels.ravel(),
weights=(centered_input * \
centered_input.conjugate()).ravel())
return bc
@@ -414,11 +446,17 @@ def _sum_centered(labels):
if (not _safely_castable_to_int(labels.dtype) or
labels.min() < 0 or labels.max() > labels.size):
+ # Use numpy.unique to generate the label indices. `new_labels` will
+ # be 1-d, but it should be interpreted as the flattened n-d array of
+ # label indices.
unique_labels, new_labels = numpy.unique(labels, return_inverse=True)
counts = numpy.bincount(new_labels)
sums = numpy.bincount(new_labels, weights=input.ravel())
if centered:
- sums_c = _sum_centered(new_labels)
+ # Compute the sum of the mean-centered squares.
+ # We must reshape new_labels to the n-d shape of `input` before
+ # passing it _sum_centered.
+ sums_c = _sum_centered(new_labels.reshape(labels.shape))
idxs = numpy.searchsorted(unique_labels, index)
# make all of idxs valid
idxs[idxs >= unique_labels.size] = 0
@@ -429,7 +467,7 @@ def _sum_centered(labels):
counts = numpy.bincount(labels.ravel())
sums = numpy.bincount(labels.ravel(), weights=input.ravel())
if centered:
- sums_c = _sum_centered(labels.ravel())
+ sums_c = _sum_centered(labels)
# make sure all index values are valid
idxs = numpy.asanyarray(index, numpy.int).copy()
found = (idxs >= 0) & (idxs < counts.size)
82 scipy/ndimage/tests/test_measurements.py
View
@@ -20,9 +20,12 @@ def test_a(self):
x = [0,1,2,6]
labels = [0,0,1,1]
index = [0,1]
- counts, sums = ndimage.measurements._stats(x, labels=labels, index=index)
- assert_array_equal(counts, [2, 2])
- assert_array_equal(sums, [1.0, 8.0])
+ for shp in [(4,), (2,2)]:
+ x = np.array(x).reshape(shp)
+ labels = np.array(labels).reshape(shp)
+ counts, sums = ndimage.measurements._stats(x, labels=labels, index=index)
+ assert_array_equal(counts, [2, 2])
+ assert_array_equal(sums, [1.0, 8.0])
def test_b(self):
# Same data as test_a, but different labels. The label 9 exceeds the
@@ -30,39 +33,51 @@ def test_b(self):
x = [0,1,2,6]
labels = [0,0,9,9]
index = [0,9]
- counts, sums = ndimage.measurements._stats(x, labels=labels, index=index)
- assert_array_equal(counts, [2, 2])
- assert_array_equal(sums, [1.0, 8.0])
+ for shp in [(4,), (2,2)]:
+ x = np.array(x).reshape(shp)
+ labels = np.array(labels).reshape(shp)
+ counts, sums = ndimage.measurements._stats(x, labels=labels, index=index)
+ assert_array_equal(counts, [2, 2])
+ assert_array_equal(sums, [1.0, 8.0])
def test_a_centered(self):
x = [0,1,2,6]
labels = [0,0,1,1]
index = [0,1]
- counts, sums, centers = ndimage.measurements._stats(x, labels=labels,
- index=index, centered=True)
- assert_array_equal(counts, [2, 2])
- assert_array_equal(sums, [1.0, 8.0])
- assert_array_equal(centers, [0.5, 8.0])
+ for shp in [(4,), (2,2)]:
+ x = np.array(x).reshape(shp)
+ labels = np.array(labels).reshape(shp)
+ counts, sums, centers = ndimage.measurements._stats(x, labels=labels,
+ index=index, centered=True)
+ assert_array_equal(counts, [2, 2])
+ assert_array_equal(sums, [1.0, 8.0])
+ assert_array_equal(centers, [0.5, 8.0])
def test_b_centered(self):
x = [0,1,2,6]
labels = [0,0,9,9]
index = [0,9]
- counts, sums, centers = ndimage.measurements._stats(x, labels=labels,
- index=index, centered=True)
- assert_array_equal(counts, [2, 2])
- assert_array_equal(sums, [1.0, 8.0])
- assert_array_equal(centers, [0.5, 8.0])
+ for shp in [(4,), (2,2)]:
+ x = np.array(x).reshape(shp)
+ labels = np.array(labels).reshape(shp)
+ counts, sums, centers = ndimage.measurements._stats(x, labels=labels,
+ index=index, centered=True)
+ assert_array_equal(counts, [2, 2])
+ assert_array_equal(sums, [1.0, 8.0])
+ assert_array_equal(centers, [0.5, 8.0])
def test_nonint_labels(self):
x = [0,1,2,6]
labels = [0.0, 0.0, 9.0, 9.0]
index = [0.0, 9.0]
- counts, sums, centers = ndimage.measurements._stats(x, labels=labels,
- index=index, centered=True)
- assert_array_equal(counts, [2, 2])
- assert_array_equal(sums, [1.0, 8.0])
- assert_array_equal(centers, [0.5, 8.0])
+ for shp in [(4,), (2,2)]:
+ x = np.array(x).reshape(shp)
+ labels = np.array(labels).reshape(shp)
+ counts, sums, centers = ndimage.measurements._stats(x, labels=labels,
+ index=index, centered=True)
+ assert_array_equal(counts, [2, 2])
+ assert_array_equal(sums, [1.0, 8.0])
+ assert_array_equal(centers, [0.5, 8.0])
class Test_measurements_select(TestCase):
@@ -975,5 +990,30 @@ def test_histogram03():
assert_array_almost_equal(output[0], expected1)
assert_array_almost_equal(output[1], expected2)
+
+def test_stat_funcs_2d():
+ """Apply the stat funcs to a 2-d array."""
+ a = np.array([[5,6,0,0,0], [8,9,0,0,0], [0,0,0,3,5]])
+ lbl = np.array([[1,1,0,0,0], [1,1,0,0,0], [0,0,0,2,2]])
+
+ mean= ndimage.mean(a, labels=lbl, index=[1, 2])
+ assert_array_equal(mean, [7.0, 4.0])
+
+ var = ndimage.variance(a, labels=lbl, index=[1, 2])
+ assert_array_equal(var, [2.5, 1.0])
+
+ std = ndimage.standard_deviation(a, labels=lbl, index=[1, 2])
+ assert_array_almost_equal(std, np.sqrt([2.5, 1.0]))
+
+ med = ndimage.median(a, labels=lbl, index=[1, 2])
+ assert_array_equal(med, [7.0, 4.0])
+
+ min = ndimage.minimum(a, labels=lbl, index=[1, 2])
+ assert_array_equal(min, [5, 3])
+
+ max = ndimage.maximum(a, labels=lbl, index=[1, 2])
+ assert_array_equal(max, [9, 5])
+
+
if __name__ == "__main__":
run_module_suite()
Something went wrong with that request. Please try again.