Skip to content

Commit

Permalink
Iterate once to create two iterators in partition (#2577)
Browse files Browse the repository at this point in the history
* Reproduce the problem

* Iterate once to create two iterators in partition

* Avoid using io.vavr.collection.Stream

* Test behavior of `partition` on different classes

* Test that Stream.partition() is lazy

* Create Iterator.duplicate() and add tests

* Change the implementation of Iterator.partition()

* Fix Set

* Fix Map

* Fix Multimap

* Move duplicate to IteratorModule

* Remove synchronized keyword

* Remove hashCode and equals

* Avoid using isEqualTo

* Remove redundant tests
  • Loading branch information
mincong-h authored and danieldietrich committed Jul 14, 2021
1 parent bd2127a commit d8cf5bb
Show file tree
Hide file tree
Showing 20 changed files with 379 additions and 34 deletions.
8 changes: 6 additions & 2 deletions vavr/src/main/java/io/vavr/collection/AbstractMultimap.java
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,12 @@ public M orElse(Supplier<? extends Iterable<? extends Tuple2<K, V>>> supplier) {
@Override
public Tuple2<M, M> partition(Predicate<? super Tuple2<K, V>> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final Tuple2<Iterator<Tuple2<K, V>>, Iterator<Tuple2<K, V>>> p = iterator().partition(predicate);
return Tuple.of((M) createFromEntries(p._1), (M) createFromEntries(p._2));
final java.util.List<Tuple2<K, V>> left = new java.util.ArrayList<>();
final java.util.List<Tuple2<K, V>> right = new java.util.ArrayList<>();
for (Tuple2<K, V> entry : this) {
(predicate.test(entry) ? left : right).add(entry);
}
return Tuple.of((M) createFromEntries(left), (M) createFromEntries(right));
}

@SuppressWarnings("unchecked")
Expand Down
3 changes: 1 addition & 2 deletions vavr/src/main/java/io/vavr/collection/BitSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -815,8 +815,7 @@ public BitSet<T> scan(T zero, BiFunction<? super T, ? super T, ? extends T> oper

@Override
public Tuple2<BitSet<T>, BitSet<T>> partition(Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
return iterator().partition(predicate).map(this::createFromAll, this::createFromAll);
return Collections.partition(this, this::createFromAll, predicate);
}

@Override
Expand Down
14 changes: 13 additions & 1 deletion vavr/src/main/java/io/vavr/collection/Collections.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package io.vavr.collection;

import io.vavr.Tuple;
import io.vavr.Tuple2;
import io.vavr.collection.JavaConverters.ChangePolicy;
import io.vavr.collection.JavaConverters.ListView;
import io.vavr.control.Option;
Expand Down Expand Up @@ -283,6 +284,17 @@ static <K, V, K2, U extends Map<K2, V>> U mapKeys(Map<K, V> source, U zero, Func
});
}

static <C extends Traversable<T>, T> Tuple2<C, C> partition(C collection, Function<Iterable<T>, C> creator,
Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final java.util.List<T> left = new java.util.ArrayList<>();
final java.util.List<T> right = new java.util.ArrayList<>();
for (T element : collection) {
(predicate.test(element) ? left : right).add(element);
}
return Tuple.of(creator.apply(left), creator.apply(right));
}

@SuppressWarnings("unchecked")
static <C extends Traversable<T>, T> C removeAll(C source, Iterable<? extends T> elements) {
Objects.requireNonNull(elements, "elements is null");
Expand Down Expand Up @@ -531,7 +543,7 @@ private static <T> IterableWithSize<T> withSizeTraversable(Iterable<? extends T>
return new IterableWithSize<>(iterable, ((Traversable<?>) iterable).size());
}
}

static class IterableWithSize<T> {
private final Iterable<? extends T> iterable;
private final int size;
Expand Down
6 changes: 2 additions & 4 deletions vavr/src/main/java/io/vavr/collection/HashSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ public boolean isAsync() {
public boolean isEmpty() {
return tree.isEmpty();
}

/**
* A {@code HashSet} is computed eagerly.
*
Expand Down Expand Up @@ -733,9 +733,7 @@ public HashSet<T> orElse(Supplier<? extends Iterable<? extends T>> supplier) {

@Override
public Tuple2<HashSet<T>, HashSet<T>> partition(Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final Tuple2<Iterator<T>, Iterator<T>> p = iterator().partition(predicate);
return Tuple.of(HashSet.ofAll(p._1), HashSet.ofAll(p._2));
return Collections.partition(this, HashSet::ofAll, predicate);
}

@Override
Expand Down
48 changes: 42 additions & 6 deletions vavr/src/main/java/io/vavr/collection/Iterator.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import java.math.BigDecimal;
import java.util.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.*;

import static java.lang.Double.NEGATIVE_INFINITY;
Expand Down Expand Up @@ -1694,10 +1695,8 @@ default Tuple2<Iterator<T>, Iterator<T>> partition(Predicate<? super T> predicat
if (!hasNext()) {
return Tuple.of(empty(), empty());
} else {
final Stream<T> that = Stream.ofAll(this);
final Iterator<T> first = that.iterator().filter(predicate);
final Iterator<T> second = that.iterator().filter(predicate.negate());
return Tuple.of(first, second);
final Tuple2<Iterator<T>, Iterator<T>> dup = IteratorModule.duplicate(this);
return Tuple.of(dup._1.filter(predicate), dup._2.filter(predicate.negate()));
}
}

Expand Down Expand Up @@ -1909,6 +1908,7 @@ default Tuple2<Iterator<T>, Iterator<T>> span(Predicate<? super T> predicate) {
}
}


@Override
default String stringPrefix() {
return "Iterator";
Expand Down Expand Up @@ -2042,6 +2042,44 @@ public T getNext() {

interface IteratorModule {

/**
* Creates two new iterators that both iterates over the same elements as
* this iterator and in the same order. The duplicate iterators are
* considered equal if they are positioned at the same element.
* <p>
* Given that most methods on iterators will make the original iterator
* unfit for further use, this methods provides a reliable way of calling
* multiple such methods on an iterator.
*
* @return a pair of iterators
*/
static <T> Tuple2<Iterator<T>, Iterator<T>> duplicate(Iterator<T> iterator) {
final java.util.Queue<T> gap = new java.util.LinkedList<>();
final AtomicReference<Iterator<T>> ahead = new AtomicReference<>();
class Partner implements Iterator<T> {

@Override
public boolean hasNext() {
return (this != ahead.get() && !gap.isEmpty()) || iterator.hasNext();
}

@Override
public T next() {
if (gap.isEmpty()) {
ahead.set(this);
}
if (this == ahead.get()) {
final T element = iterator.next();
gap.add(element);
return element;
} else {
return gap.poll();
}
}
}
return Tuple.of(new Partner(), new Partner());
}

// inspired by Scala's ConcatIterator
final class ConcatIterator<T> extends AbstractIterator<T> {

Expand All @@ -2064,10 +2102,8 @@ Cell<T> append(Iterator<T> it) {
}

private Iterator<T> curr;

private Cell<T> tail;
private Cell<T> last;

private boolean hasNextCalculated;

void append(java.util.Iterator<? extends T> that) {
Expand Down
4 changes: 1 addition & 3 deletions vavr/src/main/java/io/vavr/collection/LinkedHashSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -754,9 +754,7 @@ public LinkedHashSet<T> orElse(Supplier<? extends Iterable<? extends T>> supplie

@Override
public Tuple2<LinkedHashSet<T>, LinkedHashSet<T>> partition(Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final Tuple2<Iterator<T>, Iterator<T>> p = iterator().partition(predicate);
return Tuple.of(LinkedHashSet.ofAll(p._1), LinkedHashSet.ofAll(p._2));
return Collections.partition(this, LinkedHashSet::ofAll, predicate);
}

@Override
Expand Down
8 changes: 6 additions & 2 deletions vavr/src/main/java/io/vavr/collection/Maps.java
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,12 @@ static <T, K, V, M extends Map<K, V>> M ofStream(M map, java.util.stream.Stream<
static <K, V, M extends Map<K, V>> Tuple2<M, M> partition(M map, OfEntries<K, V, M> ofEntries,
Predicate<? super Tuple2<K, V>> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
final Tuple2<Iterator<Tuple2<K, V>>, Iterator<Tuple2<K, V>>> p = map.iterator().partition(predicate);
return Tuple.of(ofEntries.apply(p._1), ofEntries.apply(p._2));
final java.util.List<Tuple2<K, V>> left = new java.util.ArrayList<>();
final java.util.List<Tuple2<K, V>> right = new java.util.ArrayList<>();
for (Tuple2<K, V> entry : map) {
(predicate.test(entry) ? left : right).add(entry);
}
return Tuple.of(ofEntries.apply(left), ofEntries.apply(right));
}

static <K, V, M extends Map<K, V>> M peek(M map, Consumer<? super Tuple2<K, V>> action) {
Expand Down
18 changes: 9 additions & 9 deletions vavr/src/main/java/io/vavr/collection/Traversable.java
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public interface Traversable<T> extends Foldable<T>, Value<T> {
static <T> Traversable<T> narrow(Traversable<? extends T> traversable) {
return (Traversable<T>) traversable;
}

/**
* Matches each element with a unique key that you extract from it.
* If the same key is present twice, the function will return {@code None}.
Expand Down Expand Up @@ -227,7 +227,7 @@ default Option<Double> average() {
throw new UnsupportedOperationException("not numeric", x);
}
}

/**
* Collects all elements that are in the domain of the given {@code partialFunction} by mapping the elements to type {@code R}.
* <p>
Expand Down Expand Up @@ -345,7 +345,7 @@ default int count(Predicate<? super T> predicate) {
Traversable<T> dropUntil(Predicate<? super T> predicate);

/**
* Drops elements while the predicate holds for the current element.
* Drops elements while the predicate holds for the current element.
* <p>
* Note: This is essentially the same as {@code dropUntil(predicate.negate())}.
* It is intended to be used with method references, which cannot be negated directly.
Expand Down Expand Up @@ -374,7 +374,7 @@ default int count(Predicate<? super T> predicate) {
* <li>contain the same elements</li>
* <li>have the same element order, if the collections are of type Seq</li>
* </ul>
*
*
* Two Map/Multimap elements, resp. entries, (key1, value1) and (key2, value2) are equal,
* if the keys are equal and the values are equal.
* <p>
Expand Down Expand Up @@ -600,7 +600,7 @@ default T get() {
default Option<T> headOption() {
return isEmpty() ? Option.none() : Option.some(head());
}

/**
* Returns the hash code of this collection.
* <br>
Expand Down Expand Up @@ -965,7 +965,7 @@ default <U extends Comparable<? super U>> Option<T> minBy(Function<? super T, ?
return Option.some(tm);
}
}

/**
* Joins the elements of this by concatenating their string representations.
* <p>
Expand Down Expand Up @@ -1255,7 +1255,7 @@ default Option<T> reduceRightOption(BiFunction<? super T, ? super T, ? extends T
* @throws NullPointerException if {@code operation} is null.
*/
<U> Traversable<U> scanRight(U zero, BiFunction<? super T, ? super U, ? extends U> operation);

/**
* Returns the single element of this Traversable or throws, if this is empty or contains more than one element.
*
Expand Down Expand Up @@ -1418,7 +1418,7 @@ default Number sum() {
}
}
}

/**
* Drops the first element of a non-empty Traversable.
*
Expand Down Expand Up @@ -1548,7 +1548,7 @@ default Number sum() {
* @throws NullPointerException if {@code that} is null
*/
<U> Traversable<Tuple2<T, U>> zipAll(Iterable<? extends U> that, T thisElem, U thatElem);

/**
* Returns a traversable formed from this traversable and another Iterable collection by mapping elements.
* If one of the two iterables is longer than the other, its remaining elements are ignored.
Expand Down
4 changes: 1 addition & 3 deletions vavr/src/main/java/io/vavr/collection/TreeSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -801,9 +801,7 @@ public TreeSet<T> orElse(Supplier<? extends Iterable<? extends T>> supplier) {

@Override
public Tuple2<TreeSet<T>, TreeSet<T>> partition(Predicate<? super T> predicate) {
Objects.requireNonNull(predicate, "predicate is null");
return iterator().partition(predicate).map(i1 -> TreeSet.ofAll(tree.comparator(), i1),
i2 -> TreeSet.ofAll(tree.comparator(), i2));
return Collections.partition(this, values -> TreeSet.ofAll(tree.comparator(), values), predicate);
}

@Override
Expand Down
15 changes: 15 additions & 0 deletions vavr/src/test/java/io/vavr/collection/AbstractMapTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1440,6 +1440,21 @@ public void shouldReturnDefaultValue() {
assertThat(map.getOrElse("3", "3")).isEqualTo("3");
}

// -- partition

@Test
public void shouldPartitionInOneIteration() {
final AtomicInteger count = new AtomicInteger(0);
final Map<String, Integer> map = mapOf("1", 1, "2", 2, "3", 3);
final Tuple2<? extends Map<String, Integer>, ? extends Map<String, Integer>> results = map.partition(entry -> {
count.incrementAndGet();
return true;
});
assertThat(results._1).isEqualTo(mapOf("1", 1, "2", 2, "3", 3));
assertThat(results._2).isEmpty();
assertThat(count.get()).isEqualTo(3);
}

// -- spliterator

@Test
Expand Down
14 changes: 14 additions & 0 deletions vavr/src/test/java/io/vavr/collection/AbstractMultimapTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,20 @@ public void shouldPartitionIntsInOddAndEvenHavingOddAndEvenNumbers() {
mapOfTuples(Tuple.of(1, 2), Tuple.of(3, 4))));
}

@Test
@SuppressWarnings("unchecked")
public void shouldPartitionInOneIteration() {
final AtomicInteger count = new AtomicInteger(0);
final Multimap<String, Integer> map = mapOfTuples(Tuple.of("1", 1), Tuple.of("2", 2), Tuple.of("3", 3));
final Tuple2<? extends Multimap<String, Integer>, ? extends Multimap<String, Integer>> results = map.partition(entry -> {
count.incrementAndGet();
return true;
});
assertThat(results._1).isEqualTo(mapOfTuples(Tuple.of("1", 1), Tuple.of("2", 2), Tuple.of("3", 3)));
assertThat(results._2).isEmpty();
assertThat(count.get()).isEqualTo(3);
}

// -- put

@Test
Expand Down
18 changes: 17 additions & 1 deletion vavr/src/test/java/io/vavr/collection/AbstractSetTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
*/
package io.vavr.collection;

import io.vavr.Tuple2;
import org.junit.Test;

import java.math.BigDecimal;
import java.util.Spliterator;
import java.util.concurrent.atomic.AtomicInteger;

public abstract class AbstractSetTest extends AbstractTraversableRangeTest {

Expand Down Expand Up @@ -190,6 +192,20 @@ public void shouldRemoveElement() {
assertThat(empty().remove(5)).isEqualTo(empty());
}

// -- partition

@Test
public void shouldPartitionInOneIteration() {
final AtomicInteger count = new AtomicInteger(0);
final Tuple2<? extends Set<Integer>, ? extends Set<Integer>> results = of(1, 2, 3).partition(i -> {
count.incrementAndGet();
return true;
});
assertThat(results._1).isEqualTo(of(1, 2, 3));
assertThat(results._2).isEqualTo(of());
assertThat(count.get()).isEqualTo(3);
}

// -- removeAll

@Test
Expand Down Expand Up @@ -228,7 +244,7 @@ public void shouldReturnSameSetWhenEmptyUnionNonEmpty() {
assertThat(empty().union(set)).isSameAs(set);
}
}

@Test
public void shouldReturnSameSetWhenNonEmptyUnionEmpty() {
final Set<Integer> set = of(1, 2);
Expand Down
15 changes: 15 additions & 0 deletions vavr/src/test/java/io/vavr/collection/ArrayTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
Expand Down Expand Up @@ -245,6 +246,20 @@ public void shouldThrowExceptionWhenGetIndexEqualToLength() {
.isInstanceOf(IndexOutOfBoundsException.class).hasMessage("get(1)");
}

// -- partition

@Test
public void shouldPartitionInOneIteration() {
final AtomicInteger count = new AtomicInteger(0);
final Tuple2<Array<Integer>, Array<Integer>> results = of(1, 2, 3).partition(i -> {
count.incrementAndGet();
return true;
});
assertThat(results._1).isEqualTo(of(1, 2, 3));
assertThat(results._2).isEqualTo(of());
assertThat(count.get()).isEqualTo(3);
}

// -- transform()

@Test
Expand Down
Loading

0 comments on commit d8cf5bb

Please sign in to comment.