Skip to content

Commit

Permalink
add FiberRef
Browse files Browse the repository at this point in the history
  • Loading branch information
hanny24 committed May 18, 2019
1 parent 0113310 commit 99fbc89
Show file tree
Hide file tree
Showing 10 changed files with 296 additions and 125 deletions.
48 changes: 5 additions & 43 deletions core/jvm/src/test/scala/scalaz/zio/FiberLocalSpec.scala
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
package scalaz.zio

import java.util.concurrent.TimeUnit

import scalaz.zio.duration.Duration

class FiberLocalSpec(implicit ee: org.specs2.concurrent.ExecutionEnv) extends TestRuntime {

//retrieve fiber-local data that has been set $e1
//empty fiber-local data $e2
//automatically sets and frees data $e3
//fiber-local data cannot be accessed by other fibers $e4
//setting does not overwrite existing fiber-local data $e5

def is =
"FiberLocalSpec".title ^ s2"""
Create a new FiberLocal and

child sees data written by parent $e6
parent doesn't see writes by child $e7
retrieve fiber-local data that has been set $e1
empty fiber-local data $e2
automatically sets and frees data $e3
fiber-local data cannot be accessed by other fibers $e4
setting does not overwrite existing fiber-local data $e5
"""

def e1 = unsafeRun(
Expand Down Expand Up @@ -67,34 +59,4 @@ class FiberLocalSpec(implicit ee: org.specs2.concurrent.ExecutionEnv) extends Te
} yield (v1 must_=== Some(10)) and (v2 must_== Some(20))
)

def e6 = unsafeRun(
for {
local <- FiberLocal.make[Int]
_ <- local.set(10)
f <- local.get.fork
v1 <- f.join
} yield v1 must_=== Some(10)
)

def e7 = unsafeRun(
for {
local <- FiberLocal.make[Int]
_ <- local.set(42)
f <- (local.set(10) *> local.get).fork
v1 <- local.get
v2 <- f.join
} yield (v1 must_=== Some(42)) and (v2 must_=== Some(10))
)

def e8 = unsafeRun(
for {
local <- FiberLocal.make[Int]
f1 <- local.set(10).fork
f2 <- local.set(20).fork
f = f1.zip(f2)
_ <- ZIO.inheritLocals(f)
v <- local.get
} yield v must_=== ???
)

}
20 changes: 19 additions & 1 deletion core/shared/src/main/scala/scalaz/zio/Fiber.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ trait Fiber[+E, +A] { self =>
* fiber has been determined. Attempting to join a fiber that has errored will
* result in a catchable error, _if_ that error does not result from interruption.
*/
final def join: IO[E, A] = await.flatMap(IO.done)
final def join: IO[E, A] = await.flatMap(IO.done) <* inheritLocals

/**
* Interrupts the fiber with no specified reason. If the fiber has already
Expand All @@ -67,6 +67,12 @@ trait Fiber[+E, +A] { self =>
*/
def interrupt: UIO[Exit[E, A]]

/**
* Inherits values from all [[FiberLocal]] instances into current fiber.
* This will resume immediately.
*/
def inheritLocals: UIO[Unit]

/**
* Returns a fiber that prefers `this` fiber, but falls back to the
* `that` one when `this` one fails.
Expand All @@ -86,6 +92,9 @@ trait Fiber[+E, +A] { self =>

def interrupt: UIO[Exit[E1, A1]] =
self.interrupt *> that.interrupt

def inheritLocals: UIO[Unit] =
that.inheritLocals *> self.inheritLocals
}

/**
Expand All @@ -105,6 +114,8 @@ trait Fiber[+E, +A] { self =>
}

def interrupt: UIO[Exit[E1, C]] = self.interrupt.zipWith(that.interrupt)(_.zipWith(_)(f, _ && _))

def inheritLocals: UIO[Unit] = that.inheritLocals *> self.inheritLocals
}

/**
Expand Down Expand Up @@ -152,6 +163,7 @@ trait Fiber[+E, +A] { self =>
def await: UIO[Exit[E, B]] = self.await.map(_.map(f))
def poll: UIO[Option[Exit[E, B]]] = self.poll.map(_.map(_.map(f)))
def interrupt: UIO[Exit[E, B]] = self.interrupt.map(_.map(f))
def inheritLocals: UIO[Unit] = self.inheritLocals
}

/**
Expand Down Expand Up @@ -232,6 +244,7 @@ object Fiber {
def await: UIO[Exit[Nothing, Nothing]] = IO.never
def poll: UIO[Option[Exit[Nothing, Nothing]]] = IO.succeed(None)
def interrupt: UIO[Exit[Nothing, Nothing]] = IO.never
def inheritLocals: UIO[Unit] = IO.unit
}

/**
Expand All @@ -242,6 +255,8 @@ object Fiber {
def await: UIO[Exit[E, A]] = IO.succeedLazy(exit)
def poll: UIO[Option[Exit[E, A]]] = IO.succeedLazy(Some(exit))
def interrupt: UIO[Exit[E, A]] = IO.succeedLazy(exit)
def inheritLocals: UIO[Unit] = IO.unit

}

/**
Expand Down Expand Up @@ -295,5 +310,8 @@ object Fiber {
def poll: UIO[Option[Exit[Throwable, A]]] = IO.effectTotal(ftr.value.map(Exit.fromTry))

def interrupt: UIO[Exit[Throwable, A]] = join.fold(Exit.fail, Exit.succeed)

def inheritLocals: UIO[Unit] = IO.unit

}
}
27 changes: 11 additions & 16 deletions core/shared/src/main/scala/scalaz/zio/FiberLocal.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package scalaz.zio

import FiberLocal.internal._
import scalaz.zio

/**
* A container for fiber-local storage. It is the pure equivalent to Java's `ThreadLocal`
Expand All @@ -30,26 +29,26 @@ final class FiberLocal[A] private (private val state: Ref[State[A]]) extends Ser
*/
final def get: UIO[Option[A]] =
for {
maybeFiberId <- new ZIO.GetLocal(this)
value <- state.get
} yield maybeFiberId.flatMap(value.get)
descriptor <- IO.descriptor
value <- state.get
} yield value.get(descriptor.id)

/**
* Sets the value associated with the current fiber.
*/
final def set(value: A): UIO[Unit] =
for {
fiberId <- new ZIO.SetLocal(this)
_ <- state.update(_ + (fiberId -> value))
descriptor <- IO.descriptor
_ <- state.update(_ + (descriptor.id -> value))
} yield ()

/**
* Empties the value associated with the current fiber.
*/
final def empty: UIO[Unit] =
for {
fiberId <- new ZIO.SetLocal(this)
_ <- state.update(_ - fiberId)
descriptor <- IO.descriptor
_ <- state.update(_ - descriptor.id)
} yield ()

/**
Expand All @@ -66,14 +65,10 @@ object FiberLocal {
/**
* Creates a new `FiberLocal`.
*/
@deprecated("Use explicit make[A](initValue).", "1.0-RC5")
final def make[A]: UIO[FiberLocal[Option[A]]] = make(None)

/**
* Creates a new `FiberLocal`.
*/
final def make[A](initValue: A): UIO[FiberLocal[A]] =
???
final def make[A]: UIO[FiberLocal[A]] =
Ref
.make[internal.State[A]](Map())
.map(state => new FiberLocal(state))

private[zio] object internal {
type State[A] = Map[FiberId, A]
Expand Down
65 changes: 65 additions & 0 deletions core/shared/src/main/scala/scalaz/zio/FiberRef.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package scalaz.zio

import scalaz.zio

/**
* TODO: improve
* Fiber's counterpart for [[ThreadLocal]]. Value is automatically propagated
* to child on fork and merged back in after joining child.
* {{{
* for {
* fiberRef <- FiberRef.make("Hello world!")
* child <- fiberRef.set("Hi!).fork
* result <- child.join
* } yield result
* }}}
*
* `result` will be equal to "Hi!" as changes done by child were merged on join.
*
* @param initial
* @tparam A
*/
final class FiberRef[A](private[zio] val initial: A) extends Serializable {

private[this] val read = new ZIO.FiberRefGet[A](this)

private def write(value: A, fiberId: FiberId) = new ZIO.FiberRefSet[A](this, value, fiberId)

/**
* Reads the value associated with the current fiber. Returns initial value if
* no value was `set` or inherited from parent.
*/
final val get: UIO[A] = read.map(_.map(_._1).getOrElse(initial))

/**
* Sets the value associated with the current fiber.
*/
final def set(value: A): UIO[Unit] =
for {
descriptor <- ZIO.descriptor
_ <- write(value, descriptor.id)
} yield ()

/**
* Returns an `IO` that runs with `value` bound to the current fiber.
*
* Guarantees that fiber data is properly restored via `bracket`.
*/
final def locally[R, E, B](value: A)(use: ZIO[R, E, B]): ZIO[R, E, B] = {
// let's write initial value to fiber's locals map if there is no record
val readWithDefault = read.flatMap {
case Some(pair) => ZIO.succeed(pair)
case None => ZIO.descriptor.map(descriptor => (initial, descriptor.id))
}
readWithDefault.bracket(pair => write(pair._1, pair._2))(_ => set(value) *> use)
}

}

object FiberRef extends Serializable {

/**
* Creates a new `FiberRef` with given initial value.
*/
def make[A](initialValue: A): UIO[FiberRef[A]] = new zio.ZIO.FiberRefNew(initialValue)
}
4 changes: 2 additions & 2 deletions core/shared/src/main/scala/scalaz/zio/Runtime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package scalaz.zio

import scalaz.zio.internal.{ FiberContext, Platform }
import scalaz.zio.internal.{FiberContext, Platform}

/**
* A `Runtime[R]` is capable of executing tasks within an environment `R`.
Expand Down Expand Up @@ -65,7 +65,7 @@ trait Runtime[+R] {
* This method is effectful and should only be invoked at the edges of your program.
*/
final def unsafeRunAsync[E, A](zio: ZIO[R, E, A])(k: Exit[E, A] => Unit): Unit = {
val context = new FiberContext[E, A](Platform, Environment.asInstanceOf[AnyRef])
val context = new FiberContext[E, A](Platform, Environment.asInstanceOf[AnyRef], Platform.newWeakHashMap())

context.evaluateNow(zio.asInstanceOf[IO[E, A]])
context.runAsync(k)
Expand Down
31 changes: 17 additions & 14 deletions core/shared/src/main/scala/scalaz/zio/ZIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,12 @@ sealed trait ZIO[-R, +E, +A] extends Serializable { self =>
exit.foldM[E1, Either[A, B]](
_ => right.join.map(Right(_)),
a => ZIO.succeedLeft(a) <* right.interrupt
),
),
(exit, left) =>
exit.foldM[E1, Either[A, B]](
_ => left.join.map(Left(_)),
b => ZIO.succeedRight(b) <* left.interrupt
)
)
)

/**
Expand Down Expand Up @@ -495,7 +495,7 @@ sealed trait ZIO[-R, +E, +A] extends Serializable { self =>
eb match {
case Exit.Failure(_) => release(a)
case _ => ZIO.unit
}
}
)(use)

/**
Expand All @@ -514,7 +514,7 @@ sealed trait ZIO[-R, +E, +A] extends Serializable { self =>
eb match {
case Exit.Success(_) => ZIO.unit
case Exit.Failure(cause) => cleanup(cause)
}
}
)(_ => self)

/**
Expand All @@ -535,7 +535,7 @@ sealed trait ZIO[-R, +E, +A] extends Serializable { self =>
eb match {
case Exit.Failure(cause) => cause.failureOrCause.fold(_ => ZIO.unit, cleanup)
case _ => ZIO.unit
}
}
)(_ => self)

/**
Expand Down Expand Up @@ -797,7 +797,7 @@ sealed trait ZIO[-R, +E, +A] extends Serializable { self =>
schedule.update(a, state).flatMap { step =>
if (!step.cont) ZIO.succeedRight(step.finish())
else ZIO.succeed(step.state).delay(step.delay).flatMap(s => loop(Some(step.finish), s))
}
}
)

schedule.initial.flatMap(loop(None, _))
Expand Down Expand Up @@ -841,7 +841,7 @@ sealed trait ZIO[-R, +E, +A] extends Serializable { self =>
decision =>
if (decision.cont) clock.sleep(decision.delay) *> loop(decision.state)
else orElse(err, decision.finish()).map(Left(_))
),
),
succ => ZIO.succeedRight(succ)
)

Expand Down Expand Up @@ -1955,6 +1955,10 @@ object ZIO extends ZIO_R_Any {
final val Yield = 14
final val Access = 15
final val Provide = 16
final val FiberRefNew = 17
final val FiberRefGet = 18
final val FiberRefSet = 19

}
private[zio] final class FlatMap[R, E, A0, A](val zio: ZIO[R, E, A0], val k: A0 => ZIO[R, E, A])
extends ZIO[R, E, A] {
Expand Down Expand Up @@ -2046,16 +2050,15 @@ object ZIO extends ZIO_R_Any {
override def tag = Tags.Provide
}

final class InheritLocals(val fiber: Fiber[_, _]) extends UIO[Unit] {
override def tag = Tags.InheritLocals
private[zio] final class FiberRefNew[A](val initialValue: A) extends UIO[FiberRef[A]] {
override def tag = Tags.FiberRefNew
}

final class GetLocal(val fiberLocal: FiberLocal[_]) extends UIO[Option[FiberId]] {
override def tag: Int = Tags.GetLocal
private[zio] final class FiberRefGet[A](val fiberRef: FiberRef[A]) extends UIO[Option[(A, FiberId)]] {
override def tag = Tags.FiberRefGet
}

final class SetLocal(val fiberLocal: FiberLocal[_]) extends UIO[FiberId] {
override def tag: Int = Tags.SetLocal

private[zio] final class FiberRefSet[A](val fiberRef: FiberRef[A], val value: A, val fiberId: FiberId) extends UIO[Unit] {
override def tag = Tags.FiberRefSet
}
}

0 comments on commit 99fbc89

Please sign in to comment.