Skip to content

Commit

Permalink
fixed rounding bug in unique_labels_by_group
Browse files Browse the repository at this point in the history
  • Loading branch information
petrelharp committed Jul 8, 2020
1 parent 1f28265 commit b254393
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
17 changes: 10 additions & 7 deletions pyslim/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ def unique_labels_by_group(group, label, minlength=0):
In other words, if the result is ``x``, then
``x[j]`` is ``len(set(label[group == j])) == 1``.
'''
n = np.bincount(1 + group, minlength=minlength + 1)[1:]
x = np.bincount(1 + group, weights=label, minlength=minlength + 1)[1:]
x2 = np.bincount(1 + group, weights=label.astype("int64") ** 2, minlength=minlength + 1)[1:]
# (a * x)**2 = a * (a * x**2)
return np.logical_and(n > 0, x**2 == n * x2)


w = label.astype("float64")
n = np.bincount(1 + group, minlength=minlength + 1)
x = np.bincount(1 + group, weights=w, minlength=minlength + 1)
with np.errstate(divide='ignore', invalid='ignore'):
xm = x/n
xm[n == 0] = 0
w -= xm[1 + group]
gw = np.bincount(1 + group, weights=np.abs(w), minlength=minlength + 1)[1:]
# after subtracting groupwise means, should be all zero
return np.logical_and(n[1:] > 0, np.abs(gw) < 1e-7)
14 changes: 11 additions & 3 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ def verify_unique_labels_by_group(self, group, label, minlength):
x = pyslim.util.unique_labels_by_group(group, label, minlength)
self.assertGreaterEqual(len(x), minlength)
for g in range(len(x)):
if g != -1:
u = set(label[group == g])
self.assertEqual(len(u) == 1, x[g])
u = set(label[group == g])
if (len(u) == 1) != x[g]:
print(g, u, x[g], label[group == g], label.dtype)
self.assertEqual(len(u) == 1, x[g])

def test_all_same(self):
n = 10
Expand Down Expand Up @@ -50,12 +51,19 @@ def test_all_unique(self):
self.assertTrue(np.all(x))

def test_unique_labels_by_group(self):
    # Randomized check: every combination of group count, sample size,
    # label count and label offset is handed to
    # verify_unique_labels_by_group with integer, int32 and float labels.
    # The seed is fixed so the random draws are reproducible.
    np.random.seed(23)
    check = self.verify_unique_labels_by_group
    for num_groups in 3 * np.arange(2, 15):
        # candidate group ids run from -1 up to num_groups - 2
        group_ids = np.arange(num_groups) - 1
        for size in (10, 100):
            for num_labels in (2, num_groups):
                for offset in (-5, 10000000):
                    group = np.random.choice(group_ids, size=size)

                    # "integer" labels, shifted by the offset
                    int_pool = np.arange(num_labels)
                    int_labels = offset + np.random.choice(int_pool, size=size)
                    check(group, int_labels, num_groups)

                    # the same labels again, stored as int32
                    check(group, int_labels.astype("int32"), num_groups)

                    # float labels drawn from a random pool in [0, 1),
                    # shifted by the same offset
                    float_pool = np.random.uniform(0, 1, num_labels)
                    float_labels = offset + np.random.choice(float_pool, size=size)
                    check(group, float_labels, num_groups)


0 comments on commit b254393

Please sign in to comment.