diff --git a/bench/src/main/scala/cats/bench/StateTBench.scala b/bench/src/main/scala/cats/bench/StateTBench.scala new file mode 100644 index 0000000000..f193edf146 --- /dev/null +++ b/bench/src/main/scala/cats/bench/StateTBench.scala @@ -0,0 +1,63 @@ +package cats.bench + +import cats.Eval +import cats.data.StateT +import org.openjdk.jmh.annotations._ + +/** + * To run: + * + * bench/jmh:run -i 10 -wi 10 -f 2 -t 1 cats.bench.StateTBench + */ +@State(Scope.Thread) +@BenchmarkMode(Array(Mode.Throughput)) +class StateTBench { + @Param(Array("10")) + var count: Int = _ + + @Benchmark + def single(): Long = { + randLong.run(32311).value._2 + } + + @Benchmark + def repeatedLeftBinds(): Int = { + var state = randInt + var i = 0 + while (i < count) { + state = state.flatMap(int => randInt.map(_ + int)) + i += 1 + } + state.run(32312).value._2 + } + + @Benchmark + def repeatedRightBinds(): Int = { + var state = randInt + var i = 0 + while (i < count) { + val oldS = state + state = randInt.flatMap(int => oldS.map(_ + int)) + i += 1 + } + state.run(32313).value._2 + } + + def fn(seed: Long): Eval[(Long, Int)] = + Eval.now { + val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL + val n = (newSeed >>> 16).toInt + (newSeed, n) + } + + val randInt: StateT[Eval, Long, Int] = + StateT(fn) + + val randLong: StateT[Eval, Long, Long] = + for { + int1 <- randInt + int2 <- randInt + } yield { + (int1.toLong << 32) | int2 + } +} diff --git a/core/src/main/scala/cats/data/AndThen.scala b/core/src/main/scala/cats/data/AndThen.scala new file mode 100644 index 0000000000..4a383ba38b --- /dev/null +++ b/core/src/main/scala/cats/data/AndThen.scala @@ -0,0 +1,126 @@ +package cats.data + +import java.io.Serializable + +/** + * A function type of a single input that can do function composition + * (via `andThen` and `compose`) in constant stack space with amortized + * linear time application (in the number of constituent functions). + * + * Example: + * + * {{{ + * val seed = AndThen((x: Int) => x + 1)) + * val f = (0 until 10000).foldLeft(seed)((acc, _) => acc.andThen(_ + 1)) + * + * // This should not trigger stack overflow ;-) + * f(0) + * }}} + */ +private[cats] sealed abstract class AndThen[-T, +R] + extends (T => R) with Product with Serializable { + + import AndThen._ + + final def apply(a: T): R = + runLoop(a) + + override def andThen[A](g: R => A): AndThen[T, A] = { + // Fusing calls up to a certain threshold, using the fusion + // technique implemented for `cats.effect.IO#map` + this match { + case Single(f, index) if index != fusionMaxStackDepth => + Single(f.andThen(g), index + 1) + case _ => + andThenF(AndThen(g)) + } + } + + override def compose[A](g: A => T): AndThen[A, R] = { + // Fusing calls up to a certain threshold, using the fusion + // technique implemented for `cats.effect.IO#map` + this match { + case Single(f, index) if index != fusionMaxStackDepth => + Single(f.compose(g), index + 1) + case _ => + composeF(AndThen(g)) + } + } + + private def runLoop(start: T): R = { + var self: AndThen[Any, Any] = this.asInstanceOf[AndThen[Any, Any]] + var current: Any = start.asInstanceOf[Any] + var continue = true + + while (continue) { + self match { + case Single(f, _) => + current = f(current) + continue = false + + case Concat(Single(f, _), right) => + current = f(current) + self = right.asInstanceOf[AndThen[Any, Any]] + + case Concat(left @ Concat(_, _), right) => + self = left.rotateAccum(right) + } + } + current.asInstanceOf[R] + } + + private final def andThenF[X](right: AndThen[R, X]): AndThen[T, X] = + Concat(this, right) + private final def composeF[X](right: AndThen[X, T]): AndThen[X, R] = + Concat(right, this) + + // converts left-leaning to right-leaning + protected final def rotateAccum[E](_right: AndThen[R, E]): AndThen[T, E] = { + var self: AndThen[Any, Any] = this.asInstanceOf[AndThen[Any, Any]] + var right: AndThen[Any, Any] = _right.asInstanceOf[AndThen[Any, Any]] + var continue = true + while (continue) { + self match { + case Concat(left, inner) => + self = left.asInstanceOf[AndThen[Any, Any]] + right = inner.andThenF(right) + + case _ => // Single + self = self.andThenF(right) + continue = false + } + } + self.asInstanceOf[AndThen[T, E]] + } + + override def toString: String = + "AndThen$" + System.identityHashCode(this) +} + +private[cats] object AndThen { + /** Builds an [[AndThen]] reference by wrapping a plain function. */ + def apply[A, B](f: A => B): AndThen[A, B] = + f match { + case ref: AndThen[A, B] @unchecked => ref + case _ => Single(f, 0) + } + + private final case class Single[-A, +B](f: A => B, index: Int) + extends AndThen[A, B] + private final case class Concat[-A, E, +B](left: AndThen[A, E], right: AndThen[E, B]) + extends AndThen[A, B] + + /** + * Establishes the maximum stack depth when fusing `andThen` or + * `compose` calls. + * + * The default is `128`, from which we substract one as an optimization, + * a "!=" comparisson being slightly more efficient than a "<". + * + * This value was reached by taking into account the default stack + * size as set on 32 bits or 64 bits, Linux or Windows systems, + * being enough to notice performance gains, but not big enough + * to be in danger of triggering a stack-overflow error. + */ + private final val fusionMaxStackDepth = 127 +} diff --git a/core/src/main/scala/cats/data/IndexedStateT.scala b/core/src/main/scala/cats/data/IndexedStateT.scala index d6b6c2e4d3..92598380af 100644 --- a/core/src/main/scala/cats/data/IndexedStateT.scala +++ b/core/src/main/scala/cats/data/IndexedStateT.scala @@ -22,7 +22,7 @@ final class IndexedStateT[F[_], SA, SB, A](val runF: F[SA => F[(SB, A)]]) extend def flatMap[B, SC](fas: A => IndexedStateT[F, SB, SC, B])(implicit F: FlatMap[F]): IndexedStateT[F, SA, SC, B] = IndexedStateT.applyF(F.map(runF) { safsba => - safsba.andThen { fsba => + AndThen(safsba).andThen { fsba => F.flatMap(fsba) { case (sb, a) => fas(a).run(sb) } @@ -31,7 +31,7 @@ final class IndexedStateT[F[_], SA, SB, A](val runF: F[SA => F[(SB, A)]]) extend def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): IndexedStateT[F, SA, SB, B] = IndexedStateT.applyF(F.map(runF) { sfsa => - sfsa.andThen { fsa => + AndThen(sfsa).andThen { fsa => F.flatMap(fsa) { case (s, a) => F.map(faf(a))((s, _)) } } }) diff --git a/tests/src/test/scala/cats/tests/AndThenSuite.scala b/tests/src/test/scala/cats/tests/AndThenSuite.scala new file mode 100644 index 0000000000..eaabcaab59 --- /dev/null +++ b/tests/src/test/scala/cats/tests/AndThenSuite.scala @@ -0,0 +1,53 @@ +package cats.tests + +import catalysts.Platform +import cats.data._ + +class AndThenSuite extends CatsSuite { + test("compose a chain of functions with andThen") { + check { (i: Int, fs: List[Int => Int]) => + val result = fs.map(AndThen(_)).reduceOption(_.andThen(_)).map(_(i)) + val expect = fs.reduceOption(_.andThen(_)).map(_(i)) + + result == expect + } + } + + test("compose a chain of functions with compose") { + check { (i: Int, fs: List[Int => Int]) => + val result = fs.map(AndThen(_)).reduceOption(_.compose(_)).map(_(i)) + val expect = fs.reduceOption(_.compose(_)).map(_(i)) + + result == expect + } + } + + test("andThen is stack safe") { + val count = if (Platform.isJvm) 500000 else 1000 + val fs = (0 until count).map(_ => { i: Int => i + 1 }) + val result = fs.foldLeft(AndThen((x: Int) => x))(_.andThen(_))(42) + + result shouldEqual (count + 42) + } + + test("compose is stack safe") { + val count = if (Platform.isJvm) 500000 else 1000 + val fs = (0 until count).map(_ => { i: Int => i + 1 }) + val result = fs.foldLeft(AndThen((x: Int) => x))(_.compose(_))(42) + + result shouldEqual (count + 42) + } + + test("Function1 andThen is stack safe") { + val count = if (Platform.isJvm) 50000 else 1000 + val start: (Int => Int) = AndThen((x: Int) => x) + val fs = (0 until count).foldLeft(start) { (acc, _) => + acc.andThen(_ + 1) + } + fs(0) shouldEqual count + } + + test("toString") { + AndThen((x: Int) => x).toString should startWith("AndThen$") + } +} \ No newline at end of file diff --git a/tests/src/test/scala/cats/tests/IndexedStateTSuite.scala b/tests/src/test/scala/cats/tests/IndexedStateTSuite.scala index 015e75b3fd..49a1f53e1a 100644 --- a/tests/src/test/scala/cats/tests/IndexedStateTSuite.scala +++ b/tests/src/test/scala/cats/tests/IndexedStateTSuite.scala @@ -1,9 +1,9 @@ package cats package tests +import catalysts.Platform import cats.arrow.{Profunctor, Strong} import cats.data.{EitherT, IndexedStateT, State, StateT} - import cats.arrow.Profunctor import cats.kernel.instances.tuple._ import cats.laws.discipline._ @@ -251,6 +251,23 @@ class IndexedStateTSuite extends CatsSuite { got should === (expected) } + test("flatMap is stack safe on repeated left binds when F is") { + val unit = StateT.pure[Eval, Unit, Unit](()) + val count = if (Platform.isJvm) 100000 else 100 + val result = (0 until count).foldLeft(unit) { (acc, _) => + acc.flatMap(_ => unit) + } + result.run(()).value should === (((), ())) + } + + test("flatMap is stack safe on repeated right binds when F is") { + val unit = StateT.pure[Eval, Unit, Unit](()) + val count = if (Platform.isJvm) 100000 else 100 + val result = (0 until count).foldLeft(unit) { (acc, _) => + unit.flatMap(_ => acc) + } + result.run(()).value should === (((), ())) + } implicit val iso = SemigroupalTests.Isomorphisms.invariant[IndexedStateT[ListWrapper, String, Int, ?]](IndexedStateT.catsDataFunctorForIndexedStateT(ListWrapper.monad))