Skip to content

Commit

Permalink
Merge pull request #160 from zopefoundation/issue24
Browse files Browse the repository at this point in the history
Make multiunion(), union(), difference(), and intersection() work with arbitrary iterables
  • Loading branch information
jamadden committed Apr 7, 2021
2 parents 49776df + 32630b7 commit e1da7bf
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 39 deletions.
8 changes: 8 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
4.8.0 (unreleased)
==================

- Make the ``multiunion``, ``union``, ``intersection``, and
``difference`` functions accept arbitrary Python iterables (that
iterate across the correct types). Previously, the Python
implementation allowed this, but the C implementation only allowed
objects (like ``TreeSet`` or ``Bucket``) defined in the same module
providing the function. See `issue 24
<https://github.com/zopefoundation/BTrees/issues/24>`_.

- Fix persistency bug in the Python version
(`#118 <https://github.com/zopefoundation/BTrees/issues/118>`_).

Expand Down
24 changes: 24 additions & 0 deletions src/BTrees/Interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,12 @@ def difference(c1, c2):
If neither c1 nor c2 is None, the output is a Set if c1 is a Set or
TreeSet, and is a Bucket if c1 is a Bucket or BTree.
While *c1* must be one of those types, *c2* can be any Python iterable
returning the correct types of objects.
.. versionchanged:: 4.8.0
Add support for *c2* to be an arbitrary iterable.
"""

def union(c1, c2):
Expand All @@ -338,6 +344,12 @@ def union(c1, c2):
The output is a Set containing keys from the input
collections.
*c1* and *c2* can be any Python iterables returning the
correct type of objects.
.. versionchanged:: 4.8.0
Add support for arbitrary iterables.
"""

def intersection(c1, c2):
Expand All @@ -348,6 +360,12 @@ def intersection(c1, c2):
The output is a Set containing matching keys from the input
collections.
*c1* and *c2* can be any Python iterables returning the
correct type of objects.
.. versionchanged:: 4.8.0
Add support for arbitrary iterables.
"""


Expand Down Expand Up @@ -476,6 +494,9 @@ def multiunion(seq):
:class:`BTrees.IOBTree.IOBTree` for :meth:`BTrees.IOBTree.multiunion`). The keys of the
mapping are added to the union.
+ Any iterable Python object that iterates across integers. This
will be slower than the above types.
The union is returned as a Set from the same module (for example,
:meth:`BTrees.IIBTree.multiunion` returns an :class:`BTrees.IIBTree.IISet`).
Expand All @@ -484,6 +505,9 @@ def multiunion(seq):
all the integers in all the inputs are sorted via a single
linear-time radix sort, then duplicates are removed in a second
linear-time pass.
.. versionchanged:: 4.8.0
Add support for arbitrary iterables of integers.
"""

class IBTreeFamily(Interface):
Expand Down
49 changes: 49 additions & 0 deletions src/BTrees/SetOpTemplate.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,45 @@ nextKeyAsSet(SetIteration *i)
}
#endif

/* nextGenericKeyIter
 *
 * SetIteration "next" callback for the case where i->set is a plain Python
 * iterator rather than one of this module's own container types. Each call
 * pulls one object from the iterator with PyIter_Next() and caches it,
 * converted by COPY_KEY_FROM_ARG, in i->key.
 *
 * Returns 0 when a key was fetched and also when the iterator is exhausted;
 * exhaustion is reported by setting i->position to -1. Returns -1 when
 * PyIter_Next() reported an error or when the key could not be copied.
 */
static int nextGenericKeyIter(SetIteration* i)
{
PyObject* next = NULL;
/* Success flag handed to COPY_KEY_FROM_ARG; presumably cleared by the macro
 * when `next` is not an acceptable key (macro not visible here). */
int copied = 1;

if (i->position < 0)
{
/* Already finished. Do nothing. */
return 0;
}

if (i->position)
{
/* If we've been called before, release the key cache. */
DECREF_KEY(i->key);
}

/* NOTE(review): on both `return -1` paths below, position has already been
 * advanced and the previously cached key has already been released. Confirm
 * that the iteration cleanup code does not release i->key a second time when
 * it sees position > 0. */
i->position += 1;
next = PyIter_Next(i->set);
if (next == NULL)
{
/* Either an error, or the end of iteration. */
if (!PyErr_Occurred())
{
/* End of iteration. */
i->position = -1;
return 0;
}
/* Propagate the error. */
return -1;
}

COPY_KEY_FROM_ARG(i->key, next, copied);
/* PyIter_Next() returned a new reference; the copy above is all we keep. */
Py_DECREF(next);
UNLESS(copied) return -1;
/* Hold the cached key until the next call (or cleanup) releases it. */
INCREF_KEY(i->key);
return 0;
}

/* initSetIteration
*
* Start the set iteration protocol. See the comments at struct SetIteration.
Expand Down Expand Up @@ -128,6 +167,16 @@ initSetIteration(SetIteration *i, PyObject *s, int useValues)
i->next = nextKeyAsSet;
}
#endif
else if (!useValues)
{
/* If only keys are needed (no values), we can just use an iterator. */
/* Type errors in the keys are detected later, in the next() call. */
/* This is slower, but very convenient. If PyObject_GetIter() raises a */
/* TypeError, let that propagate. */
i->set = PyObject_GetIter(s); /* Return a new reference. */
UNLESS(i->set) return -1;
i->next = nextGenericKeyIter;
}
else
{
PyErr_SetString(PyExc_TypeError, "set operation: invalid argument, cannot iterate");
Expand Down
12 changes: 6 additions & 6 deletions src/BTrees/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,10 @@ def __sub__(self, other):
return difference(self.__class__, self, other)

def __or__(self, other):
    # ``self | other``: build the result with this class's set type instead
    # of the class itself, because a union only ever carries keys.
    return union(self._set_type, self, other)

def __and__(self, other):
    # ``self & other``: build the result with this class's set type instead
    # of the class itself, because an intersection only ever carries keys.
    return intersection(self._set_type, self, other)


class _SetIteration(object):
Expand Down Expand Up @@ -1168,10 +1168,10 @@ def __sub__(self, other):
return difference(self.__class__, self, other)

def __or__(self, other):
    # ``self | other``: build the result with this class's set type instead
    # of the class itself, because a union only ever carries keys.
    return union(self._set_type, self, other)

def __and__(self, other):
    # ``self & other``: build the result with this class's set type instead
    # of the class itself, because an intersection only ever carries keys.
    return intersection(self._set_type, self, other)


def _get_simple_btree_bucket_state(state):
Expand Down Expand Up @@ -1392,7 +1392,7 @@ def union(set_type, o1, o2):
return o1
i1 = _SetIteration(o1, False, 0)
i2 = _SetIteration(o2, False, 0)
result = o1._set_type()
result = set_type()
def copy(i):
result._keys.append(i.key)
while i1.active and i2.active:
Expand Down Expand Up @@ -1422,7 +1422,7 @@ def intersection(set_type, o1, o2):
return o1
i1 = _SetIteration(o1, False, 0)
i2 = _SetIteration(o2, False, 0)
result = o1._set_type()
result = set_type()
def copy(i):
result._keys.append(i.key)
while i1.active and i2.active:
Expand Down
129 changes: 100 additions & 29 deletions src/BTrees/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2283,7 +2283,7 @@ def result(keys=(), mapbuilder=mapbuilder):

# Subclasses have to set up:
# builders() - function returning functions to build inputs,
# each returned callable takes an optional keys arg
# intersection, union, difference - set to the type-correct versions
class SetResult(object):
def setUp(self):
Expand Down Expand Up @@ -2324,18 +2324,18 @@ def _difference(self, x, y):
def testNone(self):
for op in self.union, self.intersection, self.difference:
C = op(None, None)
self.assertTrue(C is None)
self.assertIsNone(C)

for op in self.union, self.intersection, self.difference:
for A in self.As:
C = op(A, None)
self.assertTrue(C is A)
self.assertIs(C, A)

C = op(None, A)
if op == self.difference:
self.assertTrue(C is None)
self.assertIsNone(C, None)
else:
self.assertTrue(C is A)
self.assertIs(C, A)

def testEmptyUnion(self):
for A in self.As:
Expand Down Expand Up @@ -2378,33 +2378,38 @@ def testUnion(self):
inputs = self.As + self.Bs
for A in inputs:
for B in inputs:
C = self.union(A, B)
self.assertTrue(not hasattr(C, "values"))
self.assertEqual(list(C), self._union(A, B))
self.assertEqual(set(A) | set(B), set(A | B))
for convert in lambda x: x, list, tuple, set:
C = self.union(convert(A), convert(B))
self.assertTrue(not hasattr(C, "values"))
self.assertEqual(list(C), self._union(A, B))
self.assertEqual(set(A) | set(B), set(A | B))

def testIntersection(self):
inputs = self.As + self.Bs
for A in inputs:
for B in inputs:
C = self.intersection(A, B)
self.assertTrue(not hasattr(C, "values"))
self.assertEqual(list(C), self._intersection(A, B))
self.assertEqual(set(A) & set(B), set(A & B))
for convert in lambda x: x, list, tuple, set:
C = self.intersection(convert(A), convert(B))
self.assertTrue(not hasattr(C, "values"))
self.assertEqual(list(C), self._intersection(A, B))
self.assertEqual(set(A) & set(B), set(A & B))

def testDifference(self):
inputs = self.As + self.Bs
for A in inputs:
for B in inputs:
C = self.difference(A, B)
# Difference preserves LHS values.
self.assertEqual(hasattr(C, "values"), hasattr(A, "values"))
want = self._difference(A, B)
if hasattr(A, "values"):
self.assertEqual(list(C.items()), want)
else:
self.assertEqual(list(C), want)
self.assertEqual(set(A) - set(B), set(A - B))
for convert in lambda x: x, list, tuple, set:
# Difference is unlike the others: The first argument
# must be a BTree type, in both C and Python.
C = self.difference(A, convert(B))
# Difference preserves LHS values.
self.assertEqual(hasattr(C, "values"), hasattr(A, "values"))
want = self._difference(A, B)
if hasattr(A, "values"):
self.assertEqual(list(C.items()), want)
else:
self.assertEqual(list(C), want)
self.assertEqual(set(A) - set(B), set(A - B))

def testLargerInputs(self): # pylint:disable=too-many-locals
from BTrees.IIBTree import IISet # pylint:disable=no-name-in-module
Expand Down Expand Up @@ -2586,7 +2591,7 @@ def setUp(self):
def testEmpty(self):
self.assertEqual(len(self.multiunion([])), 0)

def testOne(self):
def _testOne(self, builder):
for sequence in (
[3],
list(range(20)),
Expand All @@ -2598,18 +2603,84 @@ def testOne(self):
seq2 = list(reversed(sequence[:]))
seqsorted = sorted(sequence[:])
for seq in seq1, seq2, seqsorted:
for builder in self.mkset, self.mktreeset:
input = builder(seq)
output = self.multiunion([input])
self.assertEqual(len(seq), len(output))
self.assertEqual(seqsorted, list(output))
input = builder(seq)
output = self.multiunion([input])
self.assertEqual(len(seq), len(output))
self.assertEqual(seqsorted, list(output))

def testOneBTSet(self):
self._testOne(self.mkset)

def testOneBTTreeSet(self):
self._testOne(self.mktreeset)

def testOneList(self):
self._testOne(list)

def testOneTuple(self):
self._testOne(tuple)

def testOneSet(self):
self._testOne(set)

def testOneGenerator(self):
def generator(seq):
for i in seq:
yield i

self._testOne(generator)

def testValuesIgnored(self):
for builder in self.mkbucket, self.mkbtree:
for builder in self.mkbucket, self.mkbtree, dict:
input = builder([(1, 2), (3, 4), (5, 6)])
output = self.multiunion([input])
self.assertEqual([1, 3, 5], list(output))

def testValuesIgnoredNonInteger(self):
# This only uses a dict because the bucket and tree can't
# hold non-integers.
i1 = {1: 'a', 2: 'b'}
i2 = {1: 'c', 3: 'd'}

output = self.multiunion((i1, i2))
self.assertEqual([1, 2, 3], list(output))

def testRangeInputs(self):
i1 = range(3)
i2 = range(7)

output = self.multiunion((i1, i2))
self.assertEqual([0, 1, 2, 3, 4, 5, 6], list(output))

def testNegativeKeys(self):
i1 = (-1, -2, -3)
i2 = (0, 1, 2)

if not self.SUPPORTS_NEGATIVE_KEYS:
with self.assertRaises(TypeError):
self.multiunion((i2, i1))
else:
output = self.multiunion((i2, i1))
self.assertEqual([-3, -2, -1, 0, 1, 2], list(output))

def testOneIterableWithBadKeys(self):
i1 = [1, 2, 3, 'a']
for kind in list, tuple:
with self.assertRaises(TypeError):
self.multiunion((kind(i1),))

def testBadIterable(self):
class MyException(Exception):
pass

def gen():
for i in range(3):
yield i
raise MyException

with self.assertRaises(MyException):
self.multiunion((gen(),))

def testBigInput(self):
N = 100000
if (_c_optimizations_ignored() or 'Py' in type(self).__name__) and not PYPY:
Expand Down
8 changes: 4 additions & 4 deletions src/BTrees/tests/test__base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2610,8 +2610,8 @@ def test_lhs_set_rhs_mapping(self):
def test_lhs_mapping_rhs_set(self):
    # Mapping on the left, set on the right: the function under test is
    # given the mapping's set type and must return a _Set of all keys.
    lhs = self._makeMapping({'a': 13, 'b': 12, 'c': 11})
    rhs = self._makeSet('a', 'd')
    result = self._callFUT(lhs._set_type, lhs, rhs)
    self.assertIsInstance(result, _Set)
    self.assertEqual(list(result), ['a', 'b', 'c', 'd'])

def test_both_mappings_rhs_empty(self):
Expand Down Expand Up @@ -2662,8 +2662,8 @@ def test_lhs_set_rhs_mapping(self):
def test_lhs_mapping_rhs_set(self):
    # Mapping on the left, set on the right: the function under test is
    # given the mapping's set type and must return a _Set of the shared keys.
    lhs = self._makeMapping({'a': 13, 'b': 12, 'c': 11})
    rhs = self._makeSet('a', 'd')
    result = self._callFUT(lhs._set_type, lhs, rhs)
    self.assertIsInstance(result, _Set)
    self.assertEqual(list(result), ['a'])

def test_both_mappings_rhs_empty(self):
Expand Down

0 comments on commit e1da7bf

Please sign in to comment.