Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* Fix #2186: make IndexedStateT stack safe if F[_] is * IndexedStateT flatMap based on AndThen * Optimize AndThen for fused operations * Cleanup * Cleanup * Fix style issue * Cleanup AndThen * Add benchmark * Add fusionMaxStackDepth constant * fixing trailing whitespace
- Loading branch information
1 parent
f644982
commit 246b0e0
Showing
5 changed files
with
262 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
package cats.bench | ||
|
||
import cats.Eval | ||
import cats.data.StateT | ||
import org.openjdk.jmh.annotations._ | ||
|
||
/** | ||
* To run: | ||
* | ||
* bench/jmh:run -i 10 -wi 10 -f 2 -t 1 cats.bench.StateTBench | ||
*/ | ||
@State(Scope.Thread) | ||
@BenchmarkMode(Array(Mode.Throughput)) | ||
class StateTBench { | ||
@Param(Array("10")) | ||
var count: Int = _ | ||
|
||
@Benchmark | ||
def single(): Long = { | ||
randLong.run(32311).value._2 | ||
} | ||
|
||
@Benchmark | ||
def repeatedLeftBinds(): Int = { | ||
var state = randInt | ||
var i = 0 | ||
while (i < count) { | ||
state = state.flatMap(int => randInt.map(_ + int)) | ||
i += 1 | ||
} | ||
state.run(32312).value._2 | ||
} | ||
|
||
@Benchmark | ||
def repeatedRightBinds(): Int = { | ||
var state = randInt | ||
var i = 0 | ||
while (i < count) { | ||
val oldS = state | ||
state = randInt.flatMap(int => oldS.map(_ + int)) | ||
i += 1 | ||
} | ||
state.run(32313).value._2 | ||
} | ||
|
||
def fn(seed: Long): Eval[(Long, Int)] = | ||
Eval.now { | ||
val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL | ||
val n = (newSeed >>> 16).toInt | ||
(newSeed, n) | ||
} | ||
|
||
val randInt: StateT[Eval, Long, Int] = | ||
StateT(fn) | ||
|
||
val randLong: StateT[Eval, Long, Long] = | ||
for { | ||
int1 <- randInt | ||
int2 <- randInt | ||
} yield { | ||
(int1.toLong << 32) | int2 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
package cats.data | ||
|
||
import java.io.Serializable | ||
|
||
/** | ||
* A function type of a single input that can do function composition | ||
* (via `andThen` and `compose`) in constant stack space with amortized | ||
* linear time application (in the number of constituent functions). | ||
* | ||
* Example: | ||
* | ||
* {{{ | ||
* val seed = AndThen((x: Int) => x + 1)) | ||
* val f = (0 until 10000).foldLeft(seed)((acc, _) => acc.andThen(_ + 1)) | ||
* | ||
* // This should not trigger stack overflow ;-) | ||
* f(0) | ||
* }}} | ||
*/ | ||
private[cats] sealed abstract class AndThen[-T, +R] | ||
extends (T => R) with Product with Serializable { | ||
|
||
import AndThen._ | ||
|
||
final def apply(a: T): R = | ||
runLoop(a) | ||
|
||
override def andThen[A](g: R => A): AndThen[T, A] = { | ||
// Fusing calls up to a certain threshold, using the fusion | ||
// technique implemented for `cats.effect.IO#map` | ||
this match { | ||
case Single(f, index) if index != fusionMaxStackDepth => | ||
Single(f.andThen(g), index + 1) | ||
case _ => | ||
andThenF(AndThen(g)) | ||
} | ||
} | ||
|
||
override def compose[A](g: A => T): AndThen[A, R] = { | ||
// Fusing calls up to a certain threshold, using the fusion | ||
// technique implemented for `cats.effect.IO#map` | ||
this match { | ||
case Single(f, index) if index != fusionMaxStackDepth => | ||
Single(f.compose(g), index + 1) | ||
case _ => | ||
composeF(AndThen(g)) | ||
} | ||
} | ||
|
||
private def runLoop(start: T): R = { | ||
var self: AndThen[Any, Any] = this.asInstanceOf[AndThen[Any, Any]] | ||
var current: Any = start.asInstanceOf[Any] | ||
var continue = true | ||
|
||
while (continue) { | ||
self match { | ||
case Single(f, _) => | ||
current = f(current) | ||
continue = false | ||
|
||
case Concat(Single(f, _), right) => | ||
current = f(current) | ||
self = right.asInstanceOf[AndThen[Any, Any]] | ||
|
||
case Concat(left @ Concat(_, _), right) => | ||
self = left.rotateAccum(right) | ||
} | ||
} | ||
current.asInstanceOf[R] | ||
} | ||
|
||
private final def andThenF[X](right: AndThen[R, X]): AndThen[T, X] = | ||
Concat(this, right) | ||
private final def composeF[X](right: AndThen[X, T]): AndThen[X, R] = | ||
Concat(right, this) | ||
|
||
// converts left-leaning to right-leaning | ||
protected final def rotateAccum[E](_right: AndThen[R, E]): AndThen[T, E] = { | ||
var self: AndThen[Any, Any] = this.asInstanceOf[AndThen[Any, Any]] | ||
var right: AndThen[Any, Any] = _right.asInstanceOf[AndThen[Any, Any]] | ||
var continue = true | ||
while (continue) { | ||
self match { | ||
case Concat(left, inner) => | ||
self = left.asInstanceOf[AndThen[Any, Any]] | ||
right = inner.andThenF(right) | ||
|
||
case _ => // Single | ||
self = self.andThenF(right) | ||
continue = false | ||
} | ||
} | ||
self.asInstanceOf[AndThen[T, E]] | ||
} | ||
|
||
override def toString: String = | ||
"AndThen$" + System.identityHashCode(this) | ||
} | ||
|
||
private[cats] object AndThen { | ||
/** Builds an [[AndThen]] reference by wrapping a plain function. */ | ||
def apply[A, B](f: A => B): AndThen[A, B] = | ||
f match { | ||
case ref: AndThen[A, B] @unchecked => ref | ||
case _ => Single(f, 0) | ||
} | ||
|
||
private final case class Single[-A, +B](f: A => B, index: Int) | ||
extends AndThen[A, B] | ||
private final case class Concat[-A, E, +B](left: AndThen[A, E], right: AndThen[E, B]) | ||
extends AndThen[A, B] | ||
|
||
/** | ||
* Establishes the maximum stack depth when fusing `andThen` or | ||
* `compose` calls. | ||
* | ||
* The default is `128`, from which we substract one as an optimization, | ||
* a "!=" comparisson being slightly more efficient than a "<". | ||
* | ||
* This value was reached by taking into account the default stack | ||
* size as set on 32 bits or 64 bits, Linux or Windows systems, | ||
* being enough to notice performance gains, but not big enough | ||
* to be in danger of triggering a stack-overflow error. | ||
*/ | ||
private final val fusionMaxStackDepth = 127 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
package cats.tests | ||
|
||
import catalysts.Platform | ||
import cats.data._ | ||
|
||
class AndThenSuite extends CatsSuite { | ||
test("compose a chain of functions with andThen") { | ||
check { (i: Int, fs: List[Int => Int]) => | ||
val result = fs.map(AndThen(_)).reduceOption(_.andThen(_)).map(_(i)) | ||
val expect = fs.reduceOption(_.andThen(_)).map(_(i)) | ||
|
||
result == expect | ||
} | ||
} | ||
|
||
test("compose a chain of functions with compose") { | ||
check { (i: Int, fs: List[Int => Int]) => | ||
val result = fs.map(AndThen(_)).reduceOption(_.compose(_)).map(_(i)) | ||
val expect = fs.reduceOption(_.compose(_)).map(_(i)) | ||
|
||
result == expect | ||
} | ||
} | ||
|
||
test("andThen is stack safe") { | ||
val count = if (Platform.isJvm) 500000 else 1000 | ||
val fs = (0 until count).map(_ => { i: Int => i + 1 }) | ||
val result = fs.foldLeft(AndThen((x: Int) => x))(_.andThen(_))(42) | ||
|
||
result shouldEqual (count + 42) | ||
} | ||
|
||
test("compose is stack safe") { | ||
val count = if (Platform.isJvm) 500000 else 1000 | ||
val fs = (0 until count).map(_ => { i: Int => i + 1 }) | ||
val result = fs.foldLeft(AndThen((x: Int) => x))(_.compose(_))(42) | ||
|
||
result shouldEqual (count + 42) | ||
} | ||
|
||
test("Function1 andThen is stack safe") { | ||
val count = if (Platform.isJvm) 50000 else 1000 | ||
val start: (Int => Int) = AndThen((x: Int) => x) | ||
val fs = (0 until count).foldLeft(start) { (acc, _) => | ||
acc.andThen(_ + 1) | ||
} | ||
fs(0) shouldEqual count | ||
} | ||
|
||
test("toString") { | ||
AndThen((x: Int) => x).toString should startWith("AndThen$") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters