Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some FlatMap loops useful for State and Effects #2249

Merged
merged 4 commits into from
May 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/src/main/scala/cats/implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ package cats
object implicits
extends syntax.AllSyntax
with syntax.AllSyntaxBinCompat0
with syntax.AllSyntaxBinCompat1
with instances.AllInstances
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/syntax/all.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package syntax
abstract class AllSyntaxBinCompat
extends AllSyntax
with AllSyntaxBinCompat0
with AllSyntaxBinCompat1

trait AllSyntax
extends AlternativeSyntax
Expand Down Expand Up @@ -57,3 +58,5 @@ trait AllSyntaxBinCompat0
extends UnorderedTraverseSyntax
with ApplicativeErrorExtension
with TrySyntax

trait AllSyntaxBinCompat1 extends FlatMapOptionSyntax
49 changes: 49 additions & 0 deletions core/src/main/scala/cats/syntax/flatMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,26 @@ final class FlatMapOps[F[_], A](val fa: F[A]) extends AnyVal {
@deprecated("Use productLEval instead.", "1.0.0-RC2")
def forEffectEval[B](fb: Eval[F[B]])(implicit F: FlatMap[F]): F[A] =
F.productLEval(fa)(fb)

/**
* Like an infinite loop of >> calls. This is most useful effect loops
* that you want to run forever in for instance a server.
*
* This will be an infinite loop, or it will return an F[Nothing].
*
* Be careful using this.
* For instance, a List of length k will produce a list of length k^n at iteration
* n. This means if k = 0, we return an empty list, if k = 1, we loop forever
* allocating single element lists, but if we have a k > 1, we will allocate
* exponentially increasing memory and very quickly OOM.
*/
def foreverM[B](implicit F: FlatMap[F]): F[B] = {
// allocate two things once for efficiency.
val leftUnit = Left(())
val stepResult: F[Either[Unit, B]] = F.map(fa)(_ => leftUnit)
F.tailRecM(())(_ => stepResult)
}

}

final class FlattenOps[F[_], A](val ffa: F[F[A]]) extends AnyVal {
Expand Down Expand Up @@ -101,4 +121,33 @@ final class FlatMapIdOps[A](val a: A) extends AnyVal {
*}}}
*/
def tailRecM[F[_], B](f: A => F[Either[A, B]])(implicit F: FlatMap[F]): F[B] = F.tailRecM(a)(f)

/**
* iterateForeverM is almost exclusively useful for effect types. For instance,
* A may be some state, we may take the current state, run some effect to get
* a new state and repeat.
*/
def iterateForeverM[F[_], B](f: A => F[A])(implicit F: FlatMap[F]): F[B] =
tailRecM[F, B](f.andThen { fa => F.map(fa)(Left(_): Either[A, B]) })
}

trait FlatMapOptionSyntax {
implicit final def catsSyntaxFlatMapOptionOps[F[_]: FlatMap, A](foa: F[Option[A]]): FlatMapOptionOps[F, A] =
new FlatMapOptionOps[F, A](foa)
}

final class FlatMapOptionOps[F[_], A](val fopta: F[Option[A]]) extends AnyVal {
/**
* This repeats an F until we get defined values. This can be useful
* for polling type operations on State (or RNG) Monads, or in effect
* monads.
*/
def untilDefinedM(implicit F: FlatMap[F]): F[A] = {
val leftUnit: Either[Unit, A] = Left(())
val feither: F[Either[Unit, A]] = F.map(fopta) {
case None => leftUnit
case Some(a) => Right(a)
}
F.tailRecM(())(_ => feither)
}
}
4 changes: 2 additions & 2 deletions testkit/src/main/scala/cats/tests/CatsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package tests
import catalysts.Platform

import cats.instances.AllInstances
import cats.syntax.{AllSyntax, AllSyntaxBinCompat0, EqOps}
import cats.syntax.{AllSyntax, AllSyntaxBinCompat0, AllSyntaxBinCompat1, EqOps}

import org.scalactic.anyvals.{PosZDouble, PosInt, PosZInt}
import org.scalatest.{FunSuite, FunSuiteLike, Matchers}
Expand Down Expand Up @@ -36,7 +36,7 @@ trait CatsSuite extends FunSuite
with Discipline
with TestSettings
with AllInstances
with AllSyntax with AllSyntaxBinCompat0
with AllSyntax with AllSyntaxBinCompat0 with AllSyntaxBinCompat1
with StrictCatsEquality { self: FunSuiteLike =>

implicit override val generatorDrivenConfig: PropertyCheckConfiguration =
Expand Down
39 changes: 37 additions & 2 deletions tests/src/test/scala/cats/tests/IndexedStateTSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,13 @@ class IndexedStateTSuite extends CatsSuite {
got should === (expected)
}


private val stackSafeTestSize =
if (Platform.isJvm) 100000 else 100

test("flatMap is stack safe on repeated left binds when F is") {
val unit = StateT.pure[Eval, Unit, Unit](())
val count = if (Platform.isJvm) 100000 else 100
val count = stackSafeTestSize
val result = (0 until count).foldLeft(unit) { (acc, _) =>
acc.flatMap(_ => unit)
}
Expand All @@ -262,13 +266,44 @@ class IndexedStateTSuite extends CatsSuite {

test("flatMap is stack safe on repeated right binds when F is") {
val unit = StateT.pure[Eval, Unit, Unit](())
val count = if (Platform.isJvm) 100000 else 100
val count = stackSafeTestSize
val result = (0 until count).foldLeft(unit) { (acc, _) =>
unit.flatMap(_ => acc)
}
result.run(()).value should === (((), ()))
}

test("untilDefinedM works") {
val counter = State { i: Int =>
val res = if (i > stackSafeTestSize) Some(i) else None
(i + 1, res)
}

counter.untilDefinedM.run(0).value should === ((stackSafeTestSize + 2, stackSafeTestSize + 1))
}

test("foreverM works") {
val step = StateT[Either[Int, ?], Int, Unit] { i =>
if (i > stackSafeTestSize) Left(i) else Right((i + 1, ()))
}
step.foreverM.run(0) match {
case Left(big) => big should === (stackSafeTestSize + 1)
case Right((_, _)) => fail("unreachable code due to Nothing, but scalac won't let us match on it")
}
}

test("iterateForeverM works") {
val result = 0.iterateForeverM { i =>
StateT[Either[Int, ?], Int, Int] { j =>
if (j > stackSafeTestSize) Left(j) else Right((j + 1, i + 1))
}
}
result.run(0) match {
case Left(sum) => sum should === (stackSafeTestSize + 1)
case Right((_, _)) => fail("unreachable code due to Nothing, but scalac won't let us match on it")
}
}

implicit val iso = SemigroupalTests.Isomorphisms.invariant[IndexedStateT[ListWrapper, String, Int, ?]](IndexedStateT.catsDataFunctorForIndexedStateT(ListWrapper.monad))

{
Expand Down