diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java index a68c37942dee4..65edc0149b438 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java @@ -43,7 +43,9 @@ public abstract class KVStoreView implements Iterable { boolean ascending = true; String index = KVIndex.NATURAL_INDEX_NAME; Object first = null; + Object last = null; long skip = 0L; + long max = Long.MAX_VALUE; public KVStoreView(Class type) { this.type = type; @@ -74,7 +76,25 @@ public KVStoreView first(Object value) { } /** - * Skips a number of elements in the resulting iterator. + * Stops iteration at the given value of the chosen index. + */ + public KVStoreView last(Object value) { + this.last = value; + return this; + } + + /** + * Stops iteration after a number of elements has been retrieved. + */ + public KVStoreView max(long max) { + Preconditions.checkArgument(max > 0L, "max must be positive."); + this.max = max; + return this; + } + + /** + * Skips a number of elements at the start of iteration. Skipped elements are not accounted + * when using {@link #max(long)}. */ public KVStoreView skip(long n) { this.skip = n; diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java index f65152a9fc36a..73ca8afc9eb28 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java @@ -38,10 +38,12 @@ class LevelDBIterator implements KVStoreIterator { private final LevelDBTypeInfo.Index index; private final byte[] indexKeyPrefix; private final byte[] end; + private final long max; private boolean checkedNext; private T next; private boolean closed; + private long count; LevelDBIterator(LevelDB db, KVStoreView params) throws Exception { this.db = db; @@ -51,6 +53,7 @@ class LevelDBIterator implements KVStoreIterator { this.ti = db.getTypeInfo(type); this.index = ti.index(params.index); this.indexKeyPrefix = index.keyPrefix(); + this.max = params.max; byte[] firstKey; if (params.first != null) { @@ -66,14 +69,27 @@ class LevelDBIterator implements KVStoreIterator { } it.seek(firstKey); + byte[] end = null; if (ascending) { - this.end = index.end(); + end = params.last != null ? index.end(params.last) : index.end(); } else { - this.end = null; + if (params.last != null) { + end = index.start(params.last); + } if (it.hasNext()) { - it.next(); + // When descending, the caller may have set up the start of iteration at a non-existant + // entry that is guaranteed to be after the desired entry. For example, if you have a + // compound key (a, b) where b is a, integer, you may seek to the end of the elements that + // have the same "a" value by specifying Integer.MAX_VALUE for "b", and that value may not + // exist in the database. So need to check here whether the next value actually belongs to + // the set being returned by the iterator before advancing. + byte[] nextKey = it.peekNext().getKey(); + if (compare(nextKey, indexKeyPrefix) <= 0) { + it.next(); + } } } + this.end = end; if (params.skip > 0) { skip(params.skip); @@ -147,6 +163,10 @@ public synchronized void close() throws IOException { } private T loadNext() { + if (count >= max) { + return null; + } + try { while (true) { boolean hasNext = ascending ? it.hasNext() : it.hasPrev(); @@ -173,11 +193,16 @@ private T loadNext() { return null; } - // If there's a known end key and it's found, stop. - if (end != null && Arrays.equals(nextKey, end)) { - return null; + // If there's a known end key and iteration has gone past it, stop. + if (end != null) { + int comp = compare(nextKey, end) * (ascending ? 1 : -1); + if (comp > 0) { + return null; + } } + count++; + // Next element is part of the iteration, return it. if (index == null || index.isCopy()) { return db.serializer.deserialize(nextEntry.getValue(), type); @@ -228,4 +253,17 @@ private byte[] stitch(byte[]... comps) { return dest; } + private int compare(byte[] a, byte[] b) { + int diff = 0; + int minLen = Math.min(a.length, b.length); + for (int i = 0; i < minLen; i++) { + diff += (a[i] - b[i]); + if (diff != 0) { + return diff; + } + } + + return a.length - b.length; + } + } diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java index 88c7cc08984bb..6c4469e1ed5d0 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java @@ -152,111 +152,170 @@ public static void cleanup() throws Exception { @Test public void naturalIndex() throws Exception { - testIteration(NATURAL_ORDER, view(), null); + testIteration(NATURAL_ORDER, view(), null, null); } @Test public void refIndex() throws Exception { - testIteration(REF_INDEX_ORDER, view().index("id"), null); + testIteration(REF_INDEX_ORDER, view().index("id"), null, null); } @Test public void copyIndex() throws Exception { - testIteration(COPY_INDEX_ORDER, view().index("name"), null); + testIteration(COPY_INDEX_ORDER, view().index("name"), null, null); } @Test public void numericIndex() throws Exception { - testIteration(NUMERIC_INDEX_ORDER, view().index("int"), null); + testIteration(NUMERIC_INDEX_ORDER, view().index("int"), null, null); } @Test public void naturalIndexDescending() throws Exception { - testIteration(NATURAL_ORDER, view().reverse(), null); + testIteration(NATURAL_ORDER, view().reverse(), null, null); } @Test public void refIndexDescending() throws Exception { - testIteration(REF_INDEX_ORDER, view().index("id").reverse(), null); + testIteration(REF_INDEX_ORDER, view().index("id").reverse(), null, null); } @Test public void copyIndexDescending() throws Exception { - testIteration(COPY_INDEX_ORDER, view().index("name").reverse(), null); + testIteration(COPY_INDEX_ORDER, view().index("name").reverse(), null, null); } @Test public void numericIndexDescending() throws Exception { - testIteration(NUMERIC_INDEX_ORDER, view().index("int").reverse(), null); + testIteration(NUMERIC_INDEX_ORDER, view().index("int").reverse(), null, null); } @Test public void naturalIndexWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(NATURAL_ORDER, view().first(first.key), first); + CustomType1 first = pickLimit(); + testIteration(NATURAL_ORDER, view().first(first.key), first, null); } @Test public void refIndexWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(REF_INDEX_ORDER, view().index("id").first(first.id), first); + CustomType1 first = pickLimit(); + testIteration(REF_INDEX_ORDER, view().index("id").first(first.id), first, null); } @Test public void copyIndexWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(COPY_INDEX_ORDER, view().index("name").first(first.name), first); + CustomType1 first = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().index("name").first(first.name), first, null); } @Test public void numericIndexWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(NUMERIC_INDEX_ORDER, view().index("int").first(first.num), first); + CustomType1 first = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().index("int").first(first.num), first, null); } @Test public void naturalIndexDescendingWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(NATURAL_ORDER, view().reverse().first(first.key), first); + CustomType1 first = pickLimit(); + testIteration(NATURAL_ORDER, view().reverse().first(first.key), first, null); } @Test public void refIndexDescendingWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(REF_INDEX_ORDER, view().reverse().index("id").first(first.id), first); + CustomType1 first = pickLimit(); + testIteration(REF_INDEX_ORDER, view().reverse().index("id").first(first.id), first, null); } @Test public void copyIndexDescendingWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(COPY_INDEX_ORDER, view().reverse().index("name").first(first.name), - first); + CustomType1 first = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().reverse().index("name").first(first.name), first, null); } @Test public void numericIndexDescendingWithStart() throws Exception { - CustomType1 first = pickFirst(); - testIteration(NUMERIC_INDEX_ORDER, view().reverse().index("int").first(first.num), - first); + CustomType1 first = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().reverse().index("int").first(first.num), first, null); } @Test public void naturalIndexWithSkip() throws Exception { - testIteration(NATURAL_ORDER, view().skip(RND.nextInt(allEntries.size() / 2)), null); + testIteration(NATURAL_ORDER, view().skip(RND.nextInt(allEntries.size() / 2)), null, null); } @Test public void refIndexWithSkip() throws Exception { testIteration(REF_INDEX_ORDER, view().index("id").skip(RND.nextInt(allEntries.size() / 2)), - null); + null, null); } @Test public void copyIndexWithSkip() throws Exception { testIteration(COPY_INDEX_ORDER, view().index("name").skip(RND.nextInt(allEntries.size() / 2)), - null); + null, null); } + @Test + public void naturalIndexWithMax() throws Exception { + testIteration(NATURAL_ORDER, view().max(RND.nextInt(allEntries.size() / 2)), null, null); + } + + @Test + public void copyIndexWithMax() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name").max(RND.nextInt(allEntries.size() / 2)), + null, null); + } + + @Test + public void naturalIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NATURAL_ORDER, view().last(last.key), null, last); + } + + @Test + public void refIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(REF_INDEX_ORDER, view().index("id").last(last.id), null, last); + } + + @Test + public void copyIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().index("name").last(last.name), null, last); + } + + @Test + public void numericIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().index("int").last(last.num), null, last); + } + + @Test + public void naturalIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NATURAL_ORDER, view().reverse().last(last.key), null, last); + } + + @Test + public void refIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(REF_INDEX_ORDER, view().reverse().index("id").last(last.id), null, last); + } + + @Test + public void copyIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().reverse().index("name").last(last.name), + null, last); + } + + @Test + public void numericIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().reverse().index("int").last(last.num), + null, last); + } + @Test public void testRefWithIntNaturalKey() throws Exception { LevelDBSuite.IntKeyType i = new LevelDBSuite.IntKeyType(); @@ -272,8 +331,8 @@ public void testRefWithIntNaturalKey() throws Exception { } } - private CustomType1 pickFirst() { - // Picks a first element that has clashes with other elements in the given index. + private CustomType1 pickLimit() { + // Picks an element that has clashes with other elements in the given index. return clashingEntries.get(RND.nextInt(clashingEntries.size())); } @@ -297,22 +356,32 @@ private > int compareWithFallback( private void testIteration( final BaseComparator order, final KVStoreView params, - final CustomType1 first) throws Exception { + final CustomType1 first, + final CustomType1 last) throws Exception { List indexOrder = sortBy(order.fallback()); if (!params.ascending) { indexOrder = Lists.reverse(indexOrder); } Iterable expected = indexOrder; + BaseComparator expectedOrder = params.ascending ? order : order.reverse(); + if (first != null) { - final BaseComparator expectedOrder = params.ascending ? order : order.reverse(); expected = Iterables.filter(expected, v -> expectedOrder.compare(first, v) <= 0); } + if (last != null) { + expected = Iterables.filter(expected, v -> expectedOrder.compare(v, last) <= 0); + } + if (params.skip > 0) { expected = Iterables.skip(expected, (int) params.skip); } + if (params.max != Long.MAX_VALUE) { + expected = Iterables.limit(expected, (int) params.max); + } + List actual = collect(params); compareLists(expected, actual); }