diff --git a/pyslim/util.py b/pyslim/util.py index bd304559..01d1edc3 100644 --- a/pyslim/util.py +++ b/pyslim/util.py @@ -11,10 +11,13 @@ def unique_labels_by_group(group, label, minlength=0): In other words, if the result is ``x``, then ``x[j]`` is ``len(set(label[group == j])) == 1``. ''' - n = np.bincount(1 + group, minlength=minlength + 1)[1:] - x = np.bincount(1 + group, weights=label, minlength=minlength + 1)[1:] - x2 = np.bincount(1 + group, weights=label.astype("int64") ** 2, minlength=minlength + 1)[1:] - # (a * x)**2 = a * (a * x**2) - return np.logical_and(n > 0, x**2 == n * x2) - - + w = label.astype("float64") + n = np.bincount(1 + group, minlength=minlength + 1) + x = np.bincount(1 + group, weights=w, minlength=minlength + 1) + with np.errstate(divide='ignore', invalid='ignore'): + xm = x/n + xm[n == 0] = 0 + w -= xm[1 + group] + gw = np.bincount(1 + group, weights=np.abs(w), minlength=minlength + 1)[1:] + # after subtracting groupwise means, should be all zero + return np.logical_and(n[1:] > 0, np.abs(gw) < 1e-7) diff --git a/tests/test_util.py b/tests/test_util.py index 6445e43b..d73f7f43 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -16,9 +16,10 @@ def verify_unique_labels_by_group(self, group, label, minlength): x = pyslim.util.unique_labels_by_group(group, label, minlength) self.assertGreaterEqual(len(x), minlength) for g in range(len(x)): - if g != -1: - u = set(label[group == g]) - self.assertEqual(len(u) == 1, x[g]) + u = set(label[group == g]) + if (len(u) == 1) != x[g]: + print(g, u, x[g], label[group == g], label.dtype) + self.assertEqual(len(u) == 1, x[g]) def test_all_same(self): n = 10 @@ -50,12 +51,19 @@ def test_all_unique(self): self.assertTrue(np.all(x)) def test_unique_labels_by_group(self): + np.random.seed(23) for ng in 3 * np.arange(2, 15): for n in (10, 100): for nl in (2, ng): for minl in (-5, 10000000): group = np.random.choice(np.arange(ng) - 1, size=n) + # "integer" labels label = minl + np.random.choice(np.arange(nl), size=n) self.verify_unique_labels_by_group(group, label, ng) + # int32 labels + self.verify_unique_labels_by_group(group, label.astype("int32"), ng) + # and float labels + label = minl + np.random.choice(np.random.uniform(0, 1, nl), size=n) + self.verify_unique_labels_by_group(group, label, ng)