Navigation Menu

Skip to content

Commit

Permalink
Fix #2186: make IndexedStateT stack safe (#2187)
Browse files Browse the repository at this point in the history
* Fix #2186: make IndexedStateT stack safe if F[_] is

* IndexedStateT flatMap based on AndThen

* Optimize AndThen for fused operations

* Cleanup

* Cleanup

* Fix style issue

* Cleanup AndThen

* Add benchmark

* Add fusionMaxStackDepth constant

* fixing trailing whitespace
  • Loading branch information
alexandru authored and kailuowang committed Mar 14, 2018
1 parent f644982 commit 246b0e0
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 3 deletions.
63 changes: 63 additions & 0 deletions bench/src/main/scala/cats/bench/StateTBench.scala
@@ -0,0 +1,63 @@
package cats.bench

import cats.Eval
import cats.data.StateT
import org.openjdk.jmh.annotations._

/**
* To run:
*
* bench/jmh:run -i 10 -wi 10 -f 2 -t 1 cats.bench.StateTBench
*/
@State(Scope.Thread)
@BenchmarkMode(Array(Mode.Throughput))
class StateTBench {
@Param(Array("10"))
var count: Int = _

@Benchmark
def single(): Long = {
randLong.run(32311).value._2
}

@Benchmark
def repeatedLeftBinds(): Int = {
var state = randInt
var i = 0
while (i < count) {
state = state.flatMap(int => randInt.map(_ + int))
i += 1
}
state.run(32312).value._2
}

@Benchmark
def repeatedRightBinds(): Int = {
var state = randInt
var i = 0
while (i < count) {
val oldS = state
state = randInt.flatMap(int => oldS.map(_ + int))
i += 1
}
state.run(32313).value._2
}

def fn(seed: Long): Eval[(Long, Int)] =
Eval.now {
val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL
val n = (newSeed >>> 16).toInt
(newSeed, n)
}

val randInt: StateT[Eval, Long, Int] =
StateT(fn)

val randLong: StateT[Eval, Long, Long] =
for {
int1 <- randInt
int2 <- randInt
} yield {
(int1.toLong << 32) | int2
}
}
126 changes: 126 additions & 0 deletions core/src/main/scala/cats/data/AndThen.scala
@@ -0,0 +1,126 @@
package cats.data

import java.io.Serializable

/**
* A function type of a single input that can do function composition
* (via `andThen` and `compose`) in constant stack space with amortized
* linear time application (in the number of constituent functions).
*
* Example:
*
* {{{
* val seed = AndThen((x: Int) => x + 1))
* val f = (0 until 10000).foldLeft(seed)((acc, _) => acc.andThen(_ + 1))
*
* // This should not trigger stack overflow ;-)
* f(0)
* }}}
*/
private[cats] sealed abstract class AndThen[-T, +R]
extends (T => R) with Product with Serializable {

import AndThen._

final def apply(a: T): R =
runLoop(a)

override def andThen[A](g: R => A): AndThen[T, A] = {
// Fusing calls up to a certain threshold, using the fusion
// technique implemented for `cats.effect.IO#map`
this match {
case Single(f, index) if index != fusionMaxStackDepth =>
Single(f.andThen(g), index + 1)
case _ =>
andThenF(AndThen(g))
}
}

override def compose[A](g: A => T): AndThen[A, R] = {
// Fusing calls up to a certain threshold, using the fusion
// technique implemented for `cats.effect.IO#map`
this match {
case Single(f, index) if index != fusionMaxStackDepth =>
Single(f.compose(g), index + 1)
case _ =>
composeF(AndThen(g))
}
}

private def runLoop(start: T): R = {
var self: AndThen[Any, Any] = this.asInstanceOf[AndThen[Any, Any]]
var current: Any = start.asInstanceOf[Any]
var continue = true

while (continue) {
self match {
case Single(f, _) =>
current = f(current)
continue = false

case Concat(Single(f, _), right) =>
current = f(current)
self = right.asInstanceOf[AndThen[Any, Any]]

case Concat(left @ Concat(_, _), right) =>
self = left.rotateAccum(right)
}
}
current.asInstanceOf[R]
}

private final def andThenF[X](right: AndThen[R, X]): AndThen[T, X] =
Concat(this, right)
private final def composeF[X](right: AndThen[X, T]): AndThen[X, R] =
Concat(right, this)

// converts left-leaning to right-leaning
protected final def rotateAccum[E](_right: AndThen[R, E]): AndThen[T, E] = {
var self: AndThen[Any, Any] = this.asInstanceOf[AndThen[Any, Any]]
var right: AndThen[Any, Any] = _right.asInstanceOf[AndThen[Any, Any]]
var continue = true
while (continue) {
self match {
case Concat(left, inner) =>
self = left.asInstanceOf[AndThen[Any, Any]]
right = inner.andThenF(right)

case _ => // Single
self = self.andThenF(right)
continue = false
}
}
self.asInstanceOf[AndThen[T, E]]
}

override def toString: String =
"AndThen$" + System.identityHashCode(this)
}

private[cats] object AndThen {
/** Builds an [[AndThen]] reference by wrapping a plain function. */
def apply[A, B](f: A => B): AndThen[A, B] =
f match {
case ref: AndThen[A, B] @unchecked => ref
case _ => Single(f, 0)
}

private final case class Single[-A, +B](f: A => B, index: Int)
extends AndThen[A, B]
private final case class Concat[-A, E, +B](left: AndThen[A, E], right: AndThen[E, B])
extends AndThen[A, B]

/**
* Establishes the maximum stack depth when fusing `andThen` or
* `compose` calls.
*
* The default is `128`, from which we substract one as an optimization,
* a "!=" comparisson being slightly more efficient than a "<".
*
* This value was reached by taking into account the default stack
* size as set on 32 bits or 64 bits, Linux or Windows systems,
* being enough to notice performance gains, but not big enough
* to be in danger of triggering a stack-overflow error.
*/
private final val fusionMaxStackDepth = 127
}
4 changes: 2 additions & 2 deletions core/src/main/scala/cats/data/IndexedStateT.scala
Expand Up @@ -22,7 +22,7 @@ final class IndexedStateT[F[_], SA, SB, A](val runF: F[SA => F[(SB, A)]]) extend

def flatMap[B, SC](fas: A => IndexedStateT[F, SB, SC, B])(implicit F: FlatMap[F]): IndexedStateT[F, SA, SC, B] =
IndexedStateT.applyF(F.map(runF) { safsba =>
safsba.andThen { fsba =>
AndThen(safsba).andThen { fsba =>
F.flatMap(fsba) { case (sb, a) =>
fas(a).run(sb)
}
Expand All @@ -31,7 +31,7 @@ final class IndexedStateT[F[_], SA, SB, A](val runF: F[SA => F[(SB, A)]]) extend

def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): IndexedStateT[F, SA, SB, B] =
IndexedStateT.applyF(F.map(runF) { sfsa =>
sfsa.andThen { fsa =>
AndThen(sfsa).andThen { fsa =>
F.flatMap(fsa) { case (s, a) => F.map(faf(a))((s, _)) }
}
})
Expand Down
53 changes: 53 additions & 0 deletions tests/src/test/scala/cats/tests/AndThenSuite.scala
@@ -0,0 +1,53 @@
package cats.tests

import catalysts.Platform
import cats.data._

class AndThenSuite extends CatsSuite {
test("compose a chain of functions with andThen") {
check { (i: Int, fs: List[Int => Int]) =>
val result = fs.map(AndThen(_)).reduceOption(_.andThen(_)).map(_(i))
val expect = fs.reduceOption(_.andThen(_)).map(_(i))

result == expect
}
}

test("compose a chain of functions with compose") {
check { (i: Int, fs: List[Int => Int]) =>
val result = fs.map(AndThen(_)).reduceOption(_.compose(_)).map(_(i))
val expect = fs.reduceOption(_.compose(_)).map(_(i))

result == expect
}
}

test("andThen is stack safe") {
val count = if (Platform.isJvm) 500000 else 1000
val fs = (0 until count).map(_ => { i: Int => i + 1 })
val result = fs.foldLeft(AndThen((x: Int) => x))(_.andThen(_))(42)

result shouldEqual (count + 42)
}

test("compose is stack safe") {
val count = if (Platform.isJvm) 500000 else 1000
val fs = (0 until count).map(_ => { i: Int => i + 1 })
val result = fs.foldLeft(AndThen((x: Int) => x))(_.compose(_))(42)

result shouldEqual (count + 42)
}

test("Function1 andThen is stack safe") {
val count = if (Platform.isJvm) 50000 else 1000
val start: (Int => Int) = AndThen((x: Int) => x)
val fs = (0 until count).foldLeft(start) { (acc, _) =>
acc.andThen(_ + 1)
}
fs(0) shouldEqual count
}

test("toString") {
AndThen((x: Int) => x).toString should startWith("AndThen$")
}
}
19 changes: 18 additions & 1 deletion tests/src/test/scala/cats/tests/IndexedStateTSuite.scala
@@ -1,9 +1,9 @@
package cats
package tests

import catalysts.Platform
import cats.arrow.{Profunctor, Strong}
import cats.data.{EitherT, IndexedStateT, State, StateT}

import cats.arrow.Profunctor
import cats.kernel.instances.tuple._
import cats.laws.discipline._
Expand Down Expand Up @@ -251,6 +251,23 @@ class IndexedStateTSuite extends CatsSuite {
got should === (expected)
}

test("flatMap is stack safe on repeated left binds when F is") {
val unit = StateT.pure[Eval, Unit, Unit](())
val count = if (Platform.isJvm) 100000 else 100
val result = (0 until count).foldLeft(unit) { (acc, _) =>
acc.flatMap(_ => unit)
}
result.run(()).value should === (((), ()))
}

test("flatMap is stack safe on repeated right binds when F is") {
val unit = StateT.pure[Eval, Unit, Unit](())
val count = if (Platform.isJvm) 100000 else 100
val result = (0 until count).foldLeft(unit) { (acc, _) =>
unit.flatMap(_ => acc)
}
result.run(()).value should === (((), ()))
}

implicit val iso = SemigroupalTests.Isomorphisms.invariant[IndexedStateT[ListWrapper, String, Int, ?]](IndexedStateT.catsDataFunctorForIndexedStateT(ListWrapper.monad))

Expand Down

0 comments on commit 246b0e0

Please sign in to comment.