Skip to content

Commit

Permalink
Experiment: use immutable data structures in optimized traverse
Browse files Browse the repository at this point in the history
Is this necessary? The monad laws should ensure that it's safe to use
mutable builders. Nonetheless it will be good to confirm the performance
delta for using immutable data structures
  • Loading branch information
TimWSpence committed Nov 8, 2023
1 parent a1c9ff5 commit 8cc742e
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 40 deletions.
4 changes: 2 additions & 2 deletions core/src/main/scala-2.13+/cats/instances/arraySeq.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ private[cats] object ArraySeqInstances {
def traverse[G[_], A, B](fa: ArraySeq[A])(f: A => G[B])(implicit G: Applicative[G]): G[ArraySeq[B]] =
G match {
case x: StackSafeMonad[G] =>
x.map(Traverse.traverseDirectly(Vector.newBuilder[B])(fa.iterator)(f)(x))(_.to(ArraySeq.untagged))
x.map(Traverse.traverseDirectly(fa.iterator)(f)(x))(_.to(ArraySeq.untagged))
case _ =>
G.map(Chain.traverseViaChain(fa)(f))(_.iterator.to(ArraySeq.untagged))

Expand Down Expand Up @@ -233,7 +233,7 @@ private[cats] object ArraySeqInstances {
)(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[ArraySeq[B]] =
G match {
case x: StackSafeMonad[G] =>
x.map(TraverseFilter.traverseFilterDirectly(Vector.newBuilder[B])(fa.iterator)(f)(x))(
x.map(TraverseFilter.traverseFilterDirectly(fa.iterator)(f)(x))(
_.to(ArraySeq.untagged)
)
case _ =>
Expand Down
21 changes: 7 additions & 14 deletions core/src/main/scala/cats/Traverse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ package cats
import cats.data.State
import cats.data.StateT
import cats.kernel.compat.scalaVersionSpecific._
import cats.StackSafeMonad
import scala.collection.mutable

/**
* Traverse, also known as Traversable.
Expand Down Expand Up @@ -287,21 +285,16 @@ object Traverse {
@deprecated("Use cats.syntax object imports", "2.2.0")
object nonInheritedOps extends ToTraverseOps

private[cats] def traverseDirectly[Coll[x] <: IterableOnce[x], G[_], A, B](
builder: mutable.Builder[B, Coll[B]]
)(fa: IterableOnce[A])(f: A => G[B])(implicit G: StackSafeMonad[G]): G[Coll[B]] = {
val size = fa.knownSize
if (size >= 0) {
builder.sizeHint(size)
}
G.map(fa.iterator.foldLeft(G.pure(builder)) { case (accG, a) =>
private[cats] def traverseDirectly[G[_], A, B](
fa: IterableOnce[A]
)(f: A => G[B])(implicit G: StackSafeMonad[G]): G[Vector[B]] = {
fa.iterator.foldLeft(G.pure(Vector.empty[B])) { case (accG, a) =>
G.flatMap(accG) { acc =>
G.map(f(a)) { a =>
acc += a
acc
G.map(f(a)) { b =>
acc :+ b
}
}
})(_.result())
}
}

private[cats] def traverse_Directly[G[_], A, B](
Expand Down
21 changes: 8 additions & 13 deletions core/src/main/scala/cats/TraverseFilter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import cats.data.State
import cats.kernel.compat.scalaVersionSpecific._

import scala.collection.immutable.{IntMap, TreeSet}
import scala.collection.mutable

/**
* `TraverseFilter`, also known as `Witherable`, represents list-like structures
Expand Down Expand Up @@ -205,21 +204,17 @@ object TraverseFilter {
@deprecated("Use cats.syntax object imports", "2.2.0")
object nonInheritedOps extends ToTraverseFilterOps

private[cats] def traverseFilterDirectly[Coll[x] <: IterableOnce[x], G[_], A, B](
builder: mutable.Builder[B, Coll[B]]
)(fa: IterableOnce[A])(f: A => G[Option[B]])(implicit G: StackSafeMonad[G]): G[Coll[B]] = {
val size = fa.knownSize
if (size >= 0) {
builder.sizeHint(size)
}
G.map(fa.iterator.foldLeft(G.pure(builder)) { case (bldrG, a) =>
G.flatMap(bldrG) { bldr =>
private[cats] def traverseFilterDirectly[G[_], A, B](
fa: IterableOnce[A]
)(f: A => G[Option[B]])(implicit G: StackSafeMonad[G]): G[Vector[B]] = {
fa.iterator.foldLeft(G.pure(Vector.empty[B])) { case (bldrG, a) =>
G.flatMap(bldrG) { acc =>
G.map(f(a)) {
case Some(b) => bldr += b
case None => bldr
case Some(b) => acc :+ b
case None => acc
}
}
})(_.result())
}
}

}
4 changes: 2 additions & 2 deletions core/src/main/scala/cats/data/Chain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {
else
G match {
case x: StackSafeMonad[G] =>
x.map(Traverse.traverseDirectly(List.newBuilder[B])(fa.iterator)(f)(x))(Chain.fromSeq(_))
x.map(Traverse.traverseDirectly(fa.iterator)(f)(x))(Chain.fromSeq(_))
case _ =>
traverseViaChain {
val as = collection.mutable.ArrayBuffer[A]()
Expand Down Expand Up @@ -1372,7 +1372,7 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {
else
G match {
case x: StackSafeMonad[G] =>
G.map(TraverseFilter.traverseFilterDirectly(List.newBuilder[B])(fa.iterator)(f)(x))(Chain.fromSeq(_))
G.map(TraverseFilter.traverseFilterDirectly(fa.iterator)(f)(x))(Chain.fromSeq(_))
case _ =>
traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[A]()
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/cats/instances/list.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ trait ListInstances extends cats.kernel.instances.ListInstances {
if (fa.isEmpty) G.pure(Nil)
else
G match {
case x: StackSafeMonad[G] => Traverse.traverseDirectly[List, G, A, B](ListBuffer.empty[B])(fa)(f)(x)
case x: StackSafeMonad[G] => G.map(Traverse.traverseDirectly[G, A, B](fa)(f)(x))(_.toList)
case _ =>
G.map(Chain.traverseViaChain {
val as = collection.mutable.ArrayBuffer[A]()
Expand Down Expand Up @@ -320,7 +320,7 @@ private[instances] trait ListInstancesBinCompat0 {
if (fa.isEmpty) G.pure(Nil)
else
G match {
case x: StackSafeMonad[G] => TraverseFilter.traverseFilterDirectly(List.newBuilder[B])(fa)(f)(x)
case x: StackSafeMonad[G] => x.map(TraverseFilter.traverseFilterDirectly(fa)(f)(x))(_.toList)
case _ =>
G.map(Chain.traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[A]()
Expand Down
8 changes: 5 additions & 3 deletions core/src/main/scala/cats/instances/queue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ trait QueueInstances extends cats.kernel.instances.QueueInstances {
if (fa.isEmpty) G.pure(Queue.empty[B])
else
G match {
case x: StackSafeMonad[G] => Traverse.traverseDirectly(Queue.newBuilder[B])(fa)(f)(x)
case x: StackSafeMonad[G] =>
G.map(Traverse.traverseDirectly(fa)(f)(x))(fromIterableOnce(_))
case _ =>
G.map(Chain.traverseViaChain {
val as = collection.mutable.ArrayBuffer[A]()
Expand Down Expand Up @@ -222,7 +223,7 @@ trait QueueInstances extends cats.kernel.instances.QueueInstances {
@suppressUnusedImportWarningForScalaVersionSpecific
private object QueueInstances {
private val catsStdTraverseFilterForQueue: TraverseFilter[Queue] = new TraverseFilter[Queue] {
val traverse: Traverse[Queue] = cats.instances.queue.catsStdInstancesForQueue
val traverse: Traverse[Queue] with Alternative[Queue] = cats.instances.queue.catsStdInstancesForQueue

override def mapFilter[A, B](fa: Queue[A])(f: (A) => Option[B]): Queue[B] =
fa.collect(Function.unlift(f))
Expand All @@ -239,7 +240,8 @@ private object QueueInstances {
if (fa.isEmpty) G.pure(Queue.empty[B])
else
G match {
case x: StackSafeMonad[G] => TraverseFilter.traverseFilterDirectly(Queue.newBuilder[B])(fa)(f)(x)
case x: StackSafeMonad[G] =>
x.map(TraverseFilter.traverseFilterDirectly(fa)(f)(x))(traverse.fromIterableOnce(_))
case _ =>
G.map(Chain.traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[A]()
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/cats/instances/seq.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ trait SeqInstances extends cats.kernel.instances.SeqInstances {
final override def traverse[G[_], A, B](fa: Seq[A])(f: A => G[B])(implicit G: Applicative[G]): G[Seq[B]] =
G match {
case x: StackSafeMonad[G] =>
Traverse.traverseDirectly(Seq.newBuilder[B])(fa)(f)(x)
x.map(Traverse.traverseDirectly(fa)(f)(x))(_.toSeq)
case _ =>
G.map(Chain.traverseViaChain(fa.toIndexedSeq)(f))(_.toVector)
}
Expand Down Expand Up @@ -210,7 +210,7 @@ trait SeqInstances extends cats.kernel.instances.SeqInstances {

def traverseFilter[G[_], A, B](fa: Seq[A])(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[Seq[B]] =
G match {
case x: StackSafeMonad[G] => TraverseFilter.traverseFilterDirectly(Seq.newBuilder[B])(fa)(f)(x)
case x: StackSafeMonad[G] => x.map(TraverseFilter.traverseFilterDirectly(fa)(f)(x))(_.toSeq)
case _ =>
G.map(Chain.traverseFilterViaChain(fa.toIndexedSeq)(f))(_.toVector)
}
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/cats/instances/vector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ trait VectorInstances extends cats.kernel.instances.VectorInstances {

final override def traverse[G[_], A, B](fa: Vector[A])(f: A => G[B])(implicit G: Applicative[G]): G[Vector[B]] =
G match {
case x: StackSafeMonad[G] => Traverse.traverseDirectly(Vector.newBuilder[B])(fa)(f)(x)
case x: StackSafeMonad[G] => Traverse.traverseDirectly(fa)(f)(x)
case _ => G.map(Chain.traverseViaChain(fa)(f))(_.toVector)
}

Expand Down Expand Up @@ -271,7 +271,7 @@ private[instances] trait VectorInstancesBinCompat0 {

def traverseFilter[G[_], A, B](fa: Vector[A])(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[Vector[B]] =
G match {
case x: StackSafeMonad[G] => TraverseFilter.traverseFilterDirectly(Vector.newBuilder[B])(fa)(f)(x)
case x: StackSafeMonad[G] => TraverseFilter.traverseFilterDirectly(fa)(f)(x)
case _ =>
G.map(Chain.traverseFilterViaChain(fa)(f))(_.toVector)
}
Expand Down

0 comments on commit 8cc742e

Please sign in to comment.