diff --git a/tests/testSetOps.py b/tests/testSetOps.py index 2997472..c490509 100644 --- a/tests/testSetOps.py +++ b/tests/testSetOps.py @@ -82,6 +82,52 @@ def testPairs(self): got = mass_weightedIntersection([(y, w2), (x, w1)]) self.assertEqual(expected, list(got.items())) + def testMany(self): + import random + N = 15 # number of IIBTrees to feed in + L = [] + commonkey = N * 1000 + allkeys = {commonkey: 1} + for i in range(N): + t = IIBTree() + t[commonkey] = i + for j in range(N-i): + key = i + j + allkeys[key] = 1 + t[key] = N*i + j + L.append((t, i+1)) + random.shuffle(L) + allkeys = allkeys.keys() + allkeys.sort() + + # Test the union. + expected = [] + for key in allkeys: + sum = 0 + for t, w in L: + if t.has_key(key): + sum += t[key] * w + expected.append((key, sum)) + # print 'union', expected + got = mass_weightedUnion(L) + self.assertEqual(expected, list(got.items())) + + # Test the intersection. + expected = [] + for key in allkeys: + sum = 0 + for t, w in L: + if t.has_key(key): + sum += t[key] * w + else: + break + else: + # We didn't break out of the loop so it's in the intersection. + expected.append((key, sum)) + # print 'intersection', expected + got = mass_weightedIntersection(L) + self.assertEqual(expected, list(got.items())) + def test_suite(): return makeSuite(TestSetOps)