Permalink
Branch: master
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
186 lines (137 sloc) 5.04 KB
package fpinscala.state
/** A purely functional pseudo-random number generator.
  *
  * Rather than mutating hidden state, every draw returns the generated value
  * together with the generator that must be used for the next draw.
  */
trait RNG {
  /** Produces a pseudo-random `Int` and the next generator state.
    * The functions in the companion object are defined in terms of this.
    */
  def nextInt: (Int, RNG)
}

object RNG {

  /** Linear congruential generator; each step derives a new 48-bit seed from
    * the current one.
    *
    * NB - this was called SimpleRNG in the book text.
    */
  case class Simple(seed: Long) extends RNG {
    def nextInt: (Int, RNG) = {
      // `&` is bitwise AND. We use the current seed to generate a new seed,
      // keeping only its low 48 bits.
      val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL
      // The next state, which is an `RNG` instance created from the new seed.
      val nextRNG = Simple(newSeed)
      // `>>>` is right binary shift with zero fill. The value `n` is our new
      // pseudo-random integer.
      val n = (newSeed >>> 16).toInt
      // The return value is a tuple containing both a pseudo-random
      // integer and the next `RNG` state.
      (n, nextRNG)
    }
  }

  /** A random-value action: given a generator it yields an `A` together with
    * the generator to use next.
    */
  type Rand[+A] = RNG => (A, RNG)

  /** Draws a single `Int`. */
  val int: Rand[Int] = _.nextInt

  /** An action that always yields `a` and leaves the generator untouched. */
  def unit[A](a: A): Rand[A] =
    rng => (a, rng)

  /** Transforms the result of `s` with `f` without consuming extra randomness.
    * The result is itself a function awaiting a generator; see `double2` below.
    */
  def map[A,B](s: Rand[A])(f: A => B): Rand[B] =
    rng => {
      val (a, rng2) = s(rng)
      (f(a), rng2)
    }

  /** Scales a non-negative `Int` into the half-open interval [0, 1).
    * Dividing by `Int.MaxValue + 1` (rather than `Int.MaxValue`) guarantees
    * that `Int.MaxValue` maps strictly below 1.0.
    */
  private def toUnitInterval(i: Int): Double =
    i / (Int.MaxValue.toDouble + 1)

  /** Draws an `Int` in `[0, Int.MaxValue]`.
    * Negative draws `i` are mapped to `-(i + 1)`, so `Int.MinValue` becomes
    * `Int.MaxValue` instead of overflowing.
    */
  def nonNegativeInt(rng: RNG): (Int, RNG) = {
    val (i, rng2) = rng.nextInt
    (if (i < 0) -(i + 1) else i, rng2)
  }

  /** Draws a `Double` in the half-open interval [0, 1). */
  def double(rng: RNG): (Double, RNG) = {
    val (i, rng2) = nonNegativeInt(rng)
    (toUnitInterval(i), rng2)
  }

  /** Draws an `Int` followed by a `Double` in [0, 1). */
  def intDouble(rng: RNG): ((Int,Double), RNG) = {
    val (i, rng2) = rng.nextInt
    val (d, rng3) = double(rng2)
    ((i, d), rng3)
  }

  /** Draws a `Double` in [0, 1) followed by an `Int`. */
  def doubleInt(rng: RNG): ((Double,Int), RNG) = {
    val (d, rng2) = double(rng)
    val (i, rng3) = rng2.nextInt
    ((d, i), rng3)
  }

  /** Draws three consecutive `Double`s, each in [0, 1). */
  def double3(rng: RNG): ((Double,Double,Double), RNG) = {
    val (d1, rng2) = double(rng)
    val (d2, rng3) = double(rng2)
    val (d3, rng4) = double(rng3)
    ((d1, d2, d3), rng4)
  }

  /** Draws `count` consecutive `Int`s, first draw at the head of the list.
    *
    * A `count` of zero or less yields an empty list and the unchanged
    * generator. The loop is tail-recursive, so large counts do not grow
    * the call stack.
    */
  def ints(count: Int)(rng: RNG): (List[Int], RNG) = {
    @annotation.tailrec
    def go(remaining: Int, current: RNG, acc: List[Int]): (List[Int], RNG) =
      if (remaining <= 0)
        (acc.reverse, current) // acc was built back-to-front
      else {
        val (i, next) = current.nextInt
        go(remaining - 1, next, i :: acc)
      }
    go(count, rng, Nil)
  }

  /** `double` expressed with `map`. Call via: double2(myRng) */
  val double2: Rand[Double] =
    map(nonNegativeInt)(i => toUnitInterval(i))

  /** Runs `ra` then `rb`, threading the generator, and combines both results with `f`. */
  def map2[A,B,C](ra: Rand[A], rb: Rand[B])(f: (A, B) => C): Rand[C] =
    rng => {
      val (a, rng2) = ra(rng)
      val (b, rng3) = rb(rng2)
      (f(a, b), rng3)
    }

  /** Pairs the results of two actions run one after the other. */
  def both[A, B](ra: Rand[A], rb: Rand[B]): Rand[(A, B)] =
    map2(ra, rb)((_, _))

  val randIntDouble: Rand[(Int, Double)] = both(int, double2)

  val randDoubleInt: Rand[(Double, Int)] = both(double2, int)

  /** Combines a list of actions into one action that runs them left to right
    * and collects their results in the same order.
    */
  def sequence[A](fs: List[Rand[A]]): Rand[List[A]] =
    fs.foldRight(unit(List.empty[A]))((a, b) => map2(a, b)(_ :: _))

  /** `ints` expressed with `sequence`; draws `n` consecutive `Int`s. */
  def ints2(n: Int): Rand[List[Int]] = sequence(List.fill(n)(int))

  /** Runs `f`, then feeds its result to `g` to choose the action to run next. */
  def flatMap[A,B](f: Rand[A])(g: A => Rand[B]): Rand[B] =
    rng => {
      val (a, rng2) = f(rng)
      g(a)(rng2) // g(a) is itself a Rand[B]; apply it to the advanced generator
    }

  /** Draws an `Int` in `[0, n)`.
    *
    * Draws that fall in the incomplete top bucket of the non-negative range
    * are rejected and retried with the advanced generator, so the accepted
    * values are not skewed towards small remainders.
    *
    * @throws IllegalArgumentException if `n` is not positive
    */
  def nonNegativeLessThan(n: Int): Rand[Int] = {
    require(n > 0, s"n must be positive, got $n")
    flatMap(nonNegativeInt) { i =>
      val mod = i % n
      // `i + (n-1) - mod` overflows to a negative Int exactly when `i` lies
      // above the largest multiple of `n` that fits in a non-negative Int.
      if (i + (n - 1) - mod >= 0) unit(mod) else nonNegativeLessThan(n)
    }
  }

  /** `map` expressed with `flatMap`. */
  def mapViaFlat[A, B](s: Rand[A])(f: A => B): Rand[B] =
    flatMap(s)(a => unit(f(a)))

  /** `map2` expressed with `flatMap`. */
  def map2ViaFlat[A,B,C](ra: Rand[A], rb: Rand[B])(f: (A, B) => C): Rand[C] =
    flatMap(ra)(a => map(rb)(b => f(a, b)))
} // RNG
// Local import of State companion object so functions are defined in file
import State._
/** A state action: a function from a state `S` to a result `A` paired with
  * the next state.
  */
case class State[S,+A](run: S => (A, S)) {

  /** Transforms the result with `f`; the state is threaded exactly as `run` does. */
  def map[B](f: A => B): State[S, B] =
    flatMap { result => State.unit[S, B](f(result)) }

  /** Runs this action, then `sb` from the resulting state, and combines both results. */
  def map2[B,C](sb: State[S, B])(f: (A, B) => C): State[S, C] =
    for {
      first  <- this
      second <- sb
    } yield f(first, second)

  /** Runs this action and uses its result to pick the action to run next,
    * starting from the state this action produced.
    */
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { initial =>
      run(initial) match {
        case (result, advanced) => f(result).run(advanced)
      }
    }
} // State
object State {
  /** A random-value computation expressed as a state action over an `RNG`. */
  type Rand[A] = State[RNG, A]

  /** An action that yields `a` and passes the state through unchanged. */
  def unit[S, A](a: A): State[S, A] =
    State { untouched => (a, untouched) }

  /** Combines the actions into one that runs them left to right and collects
    * the results in the same order (right fold over the list).
    */
  def sequence[S, A](fs: List[State[S, A]]): State[S, List[A]] = {
    val nothingYet: State[S, List[A]] = unit(List.empty[A])
    fs.foldRight(nothingYet) { (action, rest) => action.map2(rest)(_ :: _) }
  }

  /** Same result as `sequence`, built with a left fold over the reversed list. */
  def sequence2[S, A](fs: List[State[S, A]]): State[S, List[A]] = {
    val nothingYet: State[S, List[A]] = unit(List.empty[A])
    fs.reverse.foldLeft(nothingYet) { (collected, action) =>
      action.map2(collected)(_ :: _)
    }
  }

  /** Replaces the current state with `f` applied to it; yields `()`. */
  def modify[S](f: S => S): State[S, Unit] =
    get[S].flatMap { current =>      // read the current state
      set(f(current)).map(_ => ())   // write back the transformed state
    }

  /** Yields the current state as the result, leaving it unchanged. */
  def get[S]: State[S, S] =
    State { current => (current, current) }

  /** Discards the incoming state and installs `s`; yields `()`. */
  def set[S](s: S): State[S, Unit] =
    State { _ => ((), s) }
} // State
// Closed set of inputs accepted by the candy machine simulation
// (see `Candy.simulateMachine`).
sealed trait Input
// In `Candy.simulateMachine`, a `Coin` applied to a locked machine that still
// has candies unlocks it and increments `coins`; otherwise it has no effect.
case object Coin extends Input
// In `Candy.simulateMachine`, a `Turn` applied to an unlocked machine that
// still has candies locks it and decrements `candies`; otherwise it has no effect.
case object Turn extends Input
// State of the simulated machine: whether it is locked, how many candies
// remain, and how many coins it holds.
case class Machine(locked: Boolean, candies: Int, coins: Int)
object Candy{

  /** Transition function: the machine that results from applying one input.
    *
    * - A machine with no candies ignores every input.
    * - A coin unlocks a locked machine and is added to its coins.
    * - A turn on an unlocked machine dispenses one candy and locks it.
    * - Any other combination leaves the machine as it is.
    */
  private def step(input: Input): Machine => Machine = machine =>
    if (machine.candies == 0) machine
    else input match {
      case Coin =>
        if (machine.locked) Machine(false, machine.candies, machine.coins + 1)
        else machine
      case Turn =>
        if (machine.locked) machine
        else Machine(true, machine.candies - 1, machine.coins)
    }

  /** Feeds `inputs` to the machine in order and yields `(coins, candies)`
    * of the machine after the last input.
    */
  def simulateMachine(inputs: List[Input]): State[Machine, (Int, Int)] = {
    val transitions: List[State[Machine, Unit]] =
      inputs.map(input => State.modify[Machine](step(input)))

    State.sequence(transitions).flatMap { _ =>
      State.get[Machine].map(finalMachine => (finalMachine.coins, finalMachine.candies))
    }
  }
}