Skip to content

Commit

Permalink
fixing statet to not Stack Overflow when underlying F is a free monad
Browse files Browse the repository at this point in the history
fix #498
  • Loading branch information
vmarquez authored and xuwei-k committed Oct 24, 2015
1 parent 9c37cff commit a07dc36
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 71 deletions.
5 changes: 2 additions & 3 deletions core/src/main/scala/scalaz/Bitraverse.scala
Expand Up @@ -74,11 +74,10 @@ trait Bitraverse[F[_, _]] extends Bifunctor[F] with Bifoldable[F] { self =>
import Free._
implicit val A = StateT.stateTMonadState[S, Trampoline].compose(Applicative[G])

new State[S, G[F[C, D]]] {
def apply(initial: S) = {
State[S, G[F[C, D]]]{
initial =>
val st = bitraverse[λ[α => StateT[Trampoline, S, G[α]]], A, B, C, D](fa)(f(_: A).lift[Trampoline])(g(_: B).lift[Trampoline])
st(initial).run
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/scalaz/Kleisli.scala
Expand Up @@ -69,7 +69,7 @@ final case class Kleisli[M[_], A, B](run: A => M[B]) { self =>
}
)

def state(implicit M: Functor[M]): StateT[M, A, B] =
def state(implicit M: Monad[M]): StateT[M, A, B] =
StateT(a => M.map(run(a))((a, _)))

def liftMK[T[_[_], _]](implicit T: MonadTrans[T], M: Monad[M]): Kleisli[T[M, ?], A, B] =
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/scalaz/ReaderWriterStateT.scala
Expand Up @@ -6,7 +6,7 @@ sealed abstract class IndexedReaderWriterStateT[F[_], -R, W, -S1, S2, A] {
def run(r: R, s: S1): F[(W, A, S2)]

/** Discards the writer component. */
def state(r: R)(implicit F: Functor[F]): IndexedStateT[F, S1, S2, A] =
def state(r: R)(implicit F: Monad[F]): IndexedStateT[F, S1, S2, A] =
IndexedStateT((s: S1) => F.map(run(r, s)) {
case (w, a, s1) => (s1, a)
})
Expand Down
110 changes: 58 additions & 52 deletions core/src/main/scala/scalaz/StateT.scala
Expand Up @@ -3,102 +3,107 @@ package scalaz
import Id._

trait IndexedStateT[F[_], -S1, S2, A] { self =>
def getF[S <: S1]: Monad[F] => F[S => F[(S2, A)]]

/** Run and return the final value and state in the context of `F` */
def apply(initial: S1): F[(S2, A)]
def apply(initial: S1)(implicit F: Monad[F]): F[(S2, A)] =
F.join(F.map[S1 => F[(S2, A)], F[(S2, A)]](getF(F))(sf => sf(initial)))

/** An alias for `apply` */
def run(initial: S1): F[(S2, A)] = apply(initial)
def run(initial: S1)(implicit F: Monad[F]): F[(S2, A)] = apply(initial)

/** Calls `run` using `Monoid[S].zero` as the initial state */
def runZero[S <: S1](implicit S: Monoid[S]): F[(S2, A)] =
def runZero[S <: S1](implicit S: Monoid[S], F: Monad[F]): F[(S2, A)] =
run(S.zero)

/** Run, discard the final state, and return the final value in the context of `F` */
def eval(initial: S1)(implicit F: Functor[F]): F[A] =
F.map(apply(initial))(_._2)
def eval(initial: S1)(implicit F: Monad[F]): F[A] =
F.bind[S1 => F[(S2, A)], A](getF(F))(sf => F.map(sf(initial))(_._2))

/** Calls `eval` using `Monoid[S].zero` as the initial state */
def evalZero[S <: S1](implicit F: Functor[F], S: Monoid[S]): F[A] =
def evalZero[S <: S1](implicit F: Monad[F], S: Monoid[S]): F[A] =
eval(S.zero)

/** Run, discard the final value, and return the final state in the context of `F` */
def exec(initial: S1)(implicit F: Functor[F]): F[S2] =
F.map(apply(initial))(_._1)
def exec(initial: S1)(implicit F: Monad[F]): F[S2] =
F.map(run(initial))(_._1)

/** Calls `exec` using `Monoid[S].zero` as the initial state */
def execZero[S <: S1](implicit F: Functor[F], S: Monoid[S]): F[S2] =
def execZero[S <: S1](implicit F: Monad[F], S: Monoid[S]): F[S2] =
exec(S.zero)

def map[B](f: A => B)(implicit F: Functor[F]): IndexedStateT[F, S1, S2, B] = IndexedStateT(s => F.map(apply(s)) {
case (s1, a) => (s1, f(a))
})
def map[B](f: A => B)(implicit F: Functor[F]): IndexedStateT[F, S1, S2, B] =
mapsf((sf: (S1 => F[(S2, A)])) => (s: S1) => F.map(sf(s))(t => (t._1, f(t._2))))

def xmap[X1, X2](f: S2 => X1)(g: X2 => S1)(implicit F: Functor[F]): IndexedStateT[F, X2, X1, A] = IndexedStateT(s => F.map(apply(g(s))) {
case (s1, a) => (f(s1), a)
})
def xmap[X1, X2](f: S2 => X1)(g: X2 => S1): IndexedStateT[F, X2, X1, A] = IndexedStateT.createState(
(F: Monad[F]) => (x: X2) => F.map(self(g(x))(F))(t => (f(t._1), t._2))
)

/** Map both the return value and final state using the given function. */
def mapK[G[_], B, S](f: F[(S2, A)] => G[(S, B)]): IndexedStateT[G, S1, S, B] =
IndexedStateT { s => f(apply(s)) }
def mapK[G[_], B, S](f: F[(S2, A)] => G[(S, B)])(implicit M: Monad[F]): IndexedStateT[G, S1, S, B] = IndexedStateT.createState(
(m: Monad[G]) => (s: S1) => f(apply(s)(M))
)

import BijectionT._
def bmap[X, S >: S2 <: S1](b: Bijection[S, X])(implicit F: Functor[F]): StateT[F, X, A] =
def bmap[X, S >: S2 <: S1](b: Bijection[S, X]): StateT[F, X, A] =
xmap(b to _)(b from _)

def contramap[X](g: X => S1): IndexedStateT[F, X, S2, A] =
IndexedStateT(s => apply(g(s)))
mapsf(_ compose g)

def imap[X](f: S2 => X)(implicit F: Functor[F]): IndexedStateT[F, S1, X, A] = IndexedStateT(s => F.map(apply(s)) {
case (s1, a) => (f(s1), a)
})
def imap[X](f: S2 => X)(implicit F: Functor[F]): IndexedStateT[F, S1, X, A] = bimap(f)(a => a)

def bimap[X, B](f: S2 => X)(g: A => B)(implicit F: Functor[F]): IndexedStateT[F, S1, X, B] = IndexedStateT(s => F.map(apply(s)) {
case (s1, a) => (f(s1), g(a))
})
def bimap[X, B](f: S2 => X)(g: A => B)(implicit F: Functor[F]): IndexedStateT[F, S1, X, B] = mapsf(sf => (s: S1) => F.map(sf(s))(t => (f(t._1), g(t._2)) ))

def leftMap[X](f: S2 => X)(implicit F: Functor[F]): IndexedStateT[F, S1, X, A] =
imap(f)

def flatMap[S3, B](f: A => IndexedStateT[F, S2, S3, B])(implicit F: Bind[F]): IndexedStateT[F, S1, S3, B] = IndexedStateT(s => F.bind(apply(s)) {
case (s1, a) => f(a)(s1)
})
def flatMap[S3, B](f: A => IndexedStateT[F, S2, S3, B])(implicit F: Monad[F]): IndexedStateT[F, S1, S3, B] =
mapsf(sf => (s: S1) => F.bind[(S2, A), (S3, B)](sf(s)){ t =>
val sfb: F[(S2 => F[(S3, B)])] = f(t._2).getF(F)
F.bind[S2 => F[(S3, B)], (S3, B)](sfb)(ff => ff(t._1))
})

def lift[M[_]: Applicative]: IndexedStateT[λ[α => M[F[α]]], S1, S2, A] =
new IndexedStateT[λ[α => M[F[α]]], S1, S2, A] {
def apply(initial: S1): M[F[(S2, A)]] = Applicative[M].point(self(initial))
}
def lift[M[_]](implicit F: Monad[F], M: Applicative[M]): IndexedStateT[λ[α => M[F[α]]], S1, S2, A] =
IndexedStateT.createState[λ[α => M[F[α]]], S1, S2, A](
(m: Monad[λ[α => M[F[α]]]]) => (s: S1) => M.point(self(s))
)

import Liskov._
def unlift[M[_], FF[_], S <: S1](implicit M: Comonad[M], ev: this.type <~< IndexedStateT[λ[α => M[FF[α]]], S, S2, A]): IndexedStateT[FF, S, S2, A] =
new IndexedStateT[FF, S, S2, A] {
def apply(initial: S): FF[(S2, A)] = Comonad[M].copoint(ev(self)(initial))
def unlift[M[_], FF[_], S <: S1](implicit M: Comonad[M], F: Monad[λ[α => M[FF[α]]]], ev: this.type <~< IndexedStateT[λ[α => M[FF[α]]], S, S2, A]): IndexedStateT[FF, S, S2, A] = IndexedStateT.createState(
(m: Monad[FF]) => (s: S) => {
M.copoint(ev(self)(s))
}
)

def unliftId[M[_], S <: S1](implicit M: Comonad[M], ev: this.type <~< IndexedStateT[M, S, S2, A]): IndexedState[S, S2, A] = unlift[M, Id, S]
def unliftId[M[_], S <: S1](implicit M: Comonad[M], F: Monad[M], ev: this.type <~< IndexedStateT[M, S, S2, A]): IndexedState[S, S2, A] = unlift[M, Id, S]

def rwst[W, R](implicit F: Functor[F], W: Monoid[W]): IndexedReaderWriterStateT[F, R, W, S1, S2, A] =
def rwst[W, R](implicit F: Monad[F], W: Monoid[W]): IndexedReaderWriterStateT[F, R, W, S1, S2, A] =
IndexedReaderWriterStateT(
(r, s) => F.map(self(s)) {
(r, s) => F.bind[S1 => F[(S2, A)], (W, A, S2)] (getF(F))((sf: (S1 => F[(S2, A)])) => F.map(sf(s)) {
case (s, a) => (W.zero, a, s)
}
})
)

def zoom[S0, S3, S <: S1](l: LensFamily[S0, S3, S, S2])(implicit F: Functor[F]): IndexedStateT[F, S0, S3, A] =
new IndexedStateT[F, S0, S3, A] {
def apply(s0: S0) = F.map(self(l get s0)) {
case (s2, a) => (l.set(s0, s2), a)
}
}

mapsf(sf => (s0:S0) => F.map(sf(l get s0))(t => (l.set(s0, t._1), t._2)))

def liftF[S <: S1](implicit F: Functor[IndexedStateT[F, S, S2, ?]]) =
Free.liftF[IndexedStateT[F, S, S2, ?], A](self)

def mapsf[X1, X2, B](f: (S1 => F[(S2, A)]) => (X1 => F[(X2, B)])): IndexedStateT[F, X1, X2, B] =
IndexedStateT.createState((m: Monad[F]) => f((s:S1) => run(s)(m)))
}

object IndexedStateT extends StateTInstances with StateTFunctions {
def apply[F[_], S1, S2, A](f: S1 => F[(S2, A)]): IndexedStateT[F, S1, S2, A] =
def apply[F[_], S1, S2, A](f: S1 => F[(S2, A)])(implicit F: Monad[F]): IndexedStateT[F, S1, S2, A] =
new IndexedStateT[F, S1, S2, A] {
def apply(s: S1) = f(s)
override def getF[S <: S1] = (m: Monad[F]) => F.point(f)
}

def createState[F[_], S1, S2, A](f: Monad[F] => S1 => F[(S2, A)]): IndexedStateT[F, S1, S2, A] =
new IndexedStateT[F, S1, S2, A] {
override def getF[S <: S1] = (m: Monad[F]) => m.point(f(m))
}
}

Expand Down Expand Up @@ -158,14 +163,14 @@ abstract class StateTInstances extends StateTInstances0 {

trait IndexedStateTFunctions {
def constantIndexedStateT[F[_], S1, S2, A](a: A)(s: => S2)(implicit F: Applicative[F]): IndexedStateT[F, S1, S2, A] =
IndexedStateT((_: S1) => F.point((s, a)))
IndexedStateT.createState((m: Monad[F]) => (_: S1) => F.point((s, a)))
}

trait StateTFunctions extends IndexedStateTFunctions {
def constantStateT[F[_], S, A](a: A)(s: => S)(implicit F: Applicative[F]): StateT[F, S, A] =
def constantStateT[F[_], S, A](a: A)(s: => S)(implicit F: Monad[F]): StateT[F, S, A] =
StateT((_: S) => F.point((s, a)))

def stateT[F[_], S, A](a: A)(implicit F: Applicative[F]): StateT[F, S, A] =
def stateT[F[_], S, A](a: A)(implicit F: Monad[F]): StateT[F, S, A] =
StateT(s => F.point((s, a)))
}

Expand Down Expand Up @@ -226,8 +231,9 @@ private trait StateTHoist[S] extends Hoist[λ[(g[_], a) => StateT[g, S, a]]] {
StateT(s => G.map(ga)(a => (s, a)))

def hoist[M[_]: Monad, N[_]](f: M ~> N) = new (StateTF[M, S]#f ~> StateTF[N, S]#f) {
def apply[A](action: StateT[M, S, A]) =
StateT[N, S, A](s => f(action(s)))
def apply[A](action: StateT[M, S, A]) = IndexedStateT.createState(
(n: Monad[N]) => (s: S) => f(action.run(s))
)
}

implicit def apply[G[_] : Monad]: Monad[StateT[G, S, ?]] = StateT.stateTMonadState[S, G]
Expand Down
12 changes: 3 additions & 9 deletions core/src/main/scala/scalaz/package.scala
Expand Up @@ -139,19 +139,13 @@ package object scalaz {
type State[S, A] = StateT[Id, S, A]

object StateT extends StateTInstances with StateTFunctions {
def apply[F[_], S, A](f: S => F[(S, A)]): StateT[F, S, A] = new StateT[F, S, A] {
def apply(s: S) = f(s)
}
def apply[F[_], S, A](f: S => F[(S, A)])(implicit F: Monad[F]): StateT[F, S, A] = IndexedStateT[F, S, S, A](f)
}
object IndexedState extends StateFunctions {
def apply[S1, S2, A](f: S1 => (S2, A)): IndexedState[S1, S2, A] = new IndexedState[S1, S2, A] {
def apply(s: S1) = f(s)
}
def apply[S1, S2, A](f: S1 => (S2, A)): IndexedState[S1, S2, A] = IndexedStateT[Id, S1, S2, A](f)
}
object State extends StateFunctions {
def apply[S, A](f: S => (S, A)): State[S, A] = new StateT[Id, S, A] {
def apply(s: S) = f(s)
}
def apply[S, A](f: S => (S, A)): State[S, A] = StateT[Id, S, A](f)
}

type StoreT[F[_], A, B] = IndexedStoreT[F, A, A, B]
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/scalaz/syntax/StateOps.scala
Expand Up @@ -3,7 +3,7 @@ package syntax

final class StateOps[A](val self: A) extends AnyVal {
def state[S]: State[S, A] = State.state[S, A](self)
def stateT[F[_]:Applicative, S]: StateT[F, S, A] = StateT.stateT[F, S, A](self)
def stateT[F[_]:Monad, S]: StateT[F, S, A] = StateT.stateT[F, S, A](self)
}

trait ToStateOps {
Expand Down
2 changes: 1 addition & 1 deletion effect/src/main/scala/scalaz/effect/LiftIO.scala
Expand Up @@ -56,7 +56,7 @@ object LiftIO {
def liftIO[A](ioa: IO[A]) = WriterT(LiftIO[F].liftIO(ioa.map((Monoid[W].zero, _))))
}

implicit def stateTLiftIO[F[_]: LiftIO, S] =
implicit def stateTLiftIO[F[_]: LiftIO, S](implicit F: Monad[F]) =
new LiftIO[StateT[F, S, ?]] {
def liftIO[A](ioa: IO[A]) = StateT(s => LiftIO[F].liftIO(ioa.map((s, _))))
}
Expand Down
Expand Up @@ -303,10 +303,10 @@ object ScalazArbitrary {
Functor[Arbitrary].map(A)(LazyEitherT[F, A, B](_))

// backwards compatibility
def stateTArb[F[+_], S, A](implicit A: Arbitrary[S => F[(S, A)]]): Arbitrary[StateT[F, S, A]] =
indexedStateTArb[F, S, S, A](A)
def stateTArb[F[+_], S, A](implicit A: Arbitrary[S => F[(S, A)]], F: Monad[F]): Arbitrary[StateT[F, S, A]] =
indexedStateTArb[F, S, S, A](A, F)

implicit def indexedStateTArb[F[_], S1, S2, A](implicit A: Arbitrary[S1 => F[(S2, A)]]): Arbitrary[IndexedStateT[F, S1, S2, A]] =
implicit def indexedStateTArb[F[_], S1, S2, A](implicit A: Arbitrary[S1 => F[(S2, A)]], F: Monad[F]): Arbitrary[IndexedStateT[F, S1, S2, A]] =
Functor[Arbitrary].map(A)(IndexedStateT[F, S1, S2, A](_))

implicit def eitherTArb[F[_], A, B](implicit A: Arbitrary[F[A \/ B]]): Arbitrary[EitherT[F, A, B]] =
Expand Down
7 changes: 7 additions & 0 deletions tests/src/test/scala/scalaz/StateTTest.scala
Expand Up @@ -63,4 +63,11 @@ object StateTTest extends SpecLite {
val b = StateT[List, Int, Boolean](s => List((s, true)))
instances.monadPlus[Int, List].plus(a, b).run(0) must_===(List((0, false), (0, true)))
}

"StateT can be trampolined without stack overflow" in {
import scalaz.Free._
val result = (0 to 4000).toList.map(i => StateT[Trampoline, Int, Int]((ii:Int) => Trampoline.done((i,i))))
.foldLeft(StateT((s:Int) => Trampoline.done((s,s))))( (a,b) => a.flatMap(_ => b))
4000 must_=== result(0).run._1
}
}

0 comments on commit a07dc36

Please sign in to comment.