Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FiberRef (#618) #665

Merged
merged 17 commits into from
May 24, 2019
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
package scalaz.zio.internal

import java.util.concurrent.{ Executor => _, _ }
import java.util.{ WeakHashMap, Map => JMap }
import scala.concurrent.ExecutionContext
import java.util.{ Collections, WeakHashMap, Map => JMap }

import scala.concurrent.ExecutionContext
import scalaz.zio.Exit.Cause

object PlatformLive {
Expand All @@ -40,7 +40,8 @@ object PlatformLive {
System.err.println(cause.prettyPrint)

def newWeakHashMap[A, B](): JMap[A, B] =
new WeakHashMap[A, B]()
Collections.synchronizedMap(new WeakHashMap[A, B]())
hanny24 marked this conversation as resolved.
Show resolved Hide resolved

}

final def fromExecutionContext(ec: ExecutionContext): Platform =
Expand Down
20 changes: 19 additions & 1 deletion core/shared/src/main/scala/scalaz/zio/Fiber.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ trait Fiber[+E, +A] { self =>
* fiber has been determined. Attempting to join a fiber that has errored will
* result in a catchable error, _if_ that error does not result from interruption.
*/
final def join: IO[E, A] = await.flatMap(IO.done)
final def join: IO[E, A] = await.flatMap(IO.done) <* inheritFiberRefs

/**
* Interrupts the fiber with no specified reason. If the fiber has already
Expand All @@ -67,6 +67,12 @@ trait Fiber[+E, +A] { self =>
*/
def interrupt: UIO[Exit[E, A]]

/**
* Inherits values from all [[FiberRef]] instances into current fiber.
* This will resume immediately.
*/
def inheritFiberRefs: UIO[Unit]

/**
* Returns a fiber that prefers `this` fiber, but falls back to the
* `that` one when `this` one fails.
Expand All @@ -86,6 +92,9 @@ trait Fiber[+E, +A] { self =>

def interrupt: UIO[Exit[E1, A1]] =
self.interrupt *> that.interrupt

def inheritFiberRefs: UIO[Unit] =
that.inheritFiberRefs *> self.inheritFiberRefs
}

/**
Expand All @@ -105,6 +114,8 @@ trait Fiber[+E, +A] { self =>
}

def interrupt: UIO[Exit[E1, C]] = self.interrupt.zipWith(that.interrupt)(_.zipWith(_)(f, _ && _))

def inheritFiberRefs: UIO[Unit] = that.inheritFiberRefs *> self.inheritFiberRefs
}

/**
Expand Down Expand Up @@ -152,6 +163,7 @@ trait Fiber[+E, +A] { self =>
def await: UIO[Exit[E, B]] = self.await.map(_.map(f))
def poll: UIO[Option[Exit[E, B]]] = self.poll.map(_.map(_.map(f)))
def interrupt: UIO[Exit[E, B]] = self.interrupt.map(_.map(f))
def inheritFiberRefs: UIO[Unit] = self.inheritFiberRefs
}

/**
Expand Down Expand Up @@ -232,6 +244,7 @@ object Fiber {
def await: UIO[Exit[Nothing, Nothing]] = IO.never
def poll: UIO[Option[Exit[Nothing, Nothing]]] = IO.succeed(None)
def interrupt: UIO[Exit[Nothing, Nothing]] = IO.never
def inheritFiberRefs: UIO[Unit] = IO.unit
}

/**
Expand All @@ -242,6 +255,8 @@ object Fiber {
def await: UIO[Exit[E, A]] = IO.succeedLazy(exit)
def poll: UIO[Option[Exit[E, A]]] = IO.succeedLazy(Some(exit))
def interrupt: UIO[Exit[E, A]] = IO.succeedLazy(exit)
def inheritFiberRefs: UIO[Unit] = IO.unit

}

/**
Expand Down Expand Up @@ -295,5 +310,8 @@ object Fiber {
def poll: UIO[Option[Exit[Throwable, A]]] = IO.effectTotal(ftr.value.map(Exit.fromTry))

def interrupt: UIO[Exit[Throwable, A]] = join.fold(Exit.fail, Exit.succeed)

def inheritFiberRefs: UIO[Unit] = IO.unit

}
}
105 changes: 105 additions & 0 deletions core/shared/src/main/scala/scalaz/zio/FiberRef.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2017-2019 John A. De Goes and the ZIO Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package scalaz.zio

/**
* Fiber's counterpart for Java's `ThreadLocal`. Value is automatically propagated
* to child on fork and merged back in after joining child.
* {{{
* for {
* fiberRef <- FiberRef.make("Hello world!")
* child <- fiberRef.set("Hi!).fork
* result <- child.join
* } yield result
* }}}
*
* `result` will be equal to "Hi!" as changes done by child were merged on join.
*
* @param initial
* @tparam A
*/
final class FiberRef[A](private[zio] val initial: A) extends Serializable {
hanny24 marked this conversation as resolved.
Show resolved Hide resolved

/**
* Reads the value associated with the current fiber. Returns initial value if
* no value was `set` or inherited from parent.
*/
final val get: UIO[A] = modify(v => (v, v))

/**
* Returns an `IO` that runs with `value` bound to the current fiber.
*
* Guarantees that fiber data is properly restored via `bracket`.
*/
final def locally[R, E, B](value: A)(use: ZIO[R, E, B]): ZIO[R, E, B] =
for {
oldValue <- get
b <- {
// TODO: Dotty doesn't infer this properly
val i0: ZIO.BracketAcquire_[R, E] = set(value).bracket_[R, E]
i0(set(oldValue))(use)
}
} yield b

/**
* Atomically modifies the `FiberRef` with the specified function, which computes
* a return value for the modification. This is a more powerful version of
* `update`.
*/
final def modify[B](f: A => (B, A)): UIO[B] = new ZIO.FiberRefModify(this, f)

/**
* Atomically modifies the `FiberRef` with the specified partial function, which computes
* a return value for the modification if the function is defined in the current value
* otherwise it returns a default value.
* This is a more powerful version of `updateSome`.
*/
final def modifySome[B](default: B)(pf: PartialFunction[A, (B, A)]): UIO[B] = modify { v =>
pf.applyOrElse[A, (B, A)](v, _ => (default, v))
}

/**
* Sets the value associated with the current fiber.
*/
final def set(value: A): UIO[Unit] = modify(_ => ((), value))

/**
* Atomically modifies the `FiberRef` with the specified function.
*/
final def update(f: A => A): UIO[A] = modify { v =>
val result = f(v)
(result, result)
}

/**
* Atomically modifies the `FiberRef` with the specified partial function.
* if the function is undefined in the current value it returns the old value without changing it.
*/
final def updateSome(pf: PartialFunction[A, A]): UIO[A] = modify { v =>
val result = pf.applyOrElse[A, A](v, identity)
(result, result)
}

}

object FiberRef extends Serializable {

/**
* Creates a new `FiberRef` with given initial value.
*/
def make[A](initialValue: A): UIO[FiberRef[A]] = new ZIO.FiberRefNew(initialValue)
}
2 changes: 1 addition & 1 deletion core/shared/src/main/scala/scalaz/zio/Runtime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ trait Runtime[+R] {
* This method is effectful and should only be invoked at the edges of your program.
*/
final def unsafeRunAsync[E, A](zio: ZIO[R, E, A])(k: Exit[E, A] => Unit): Unit = {
val context = new FiberContext[E, A](Platform, Environment.asInstanceOf[AnyRef])
val context = new FiberContext[E, A](Platform, Environment.asInstanceOf[AnyRef], Platform.newWeakHashMap())

context.evaluateNow(zio.asInstanceOf[IO[E, A]])
context.runAsync(k)
Expand Down
10 changes: 10 additions & 0 deletions core/shared/src/main/scala/scalaz/zio/ZIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1960,6 +1960,8 @@ object ZIO extends ZIO_R_Any {
final val Access = 14
final val Provide = 15
final val SuspendWith = 16
final val FiberRefNew = 17
final val FiberRefModify = 18
hanny24 marked this conversation as resolved.
Show resolved Hide resolved
}
private[zio] final class FlatMap[R, E, A0, A](val zio: ZIO[R, E, A0], val k: A0 => ZIO[R, E, A])
extends ZIO[R, E, A] {
Expand Down Expand Up @@ -2050,4 +2052,12 @@ object ZIO extends ZIO_R_Any {
private[zio] final class SuspendWith[R, E, A](val f: Platform => ZIO[R, E, A]) extends ZIO[R, E, A] {
override def tag = Tags.SuspendWith
}

private[zio] final class FiberRefNew[A](val initialValue: A) extends UIO[FiberRef[A]] {
override def tag = Tags.FiberRefNew
}

private[zio] final class FiberRefModify[A, B](val fiberRef: FiberRef[A], val f: A => (B, A)) extends UIO[B] {
override def tag = Tags.FiberRefModify
}
}
44 changes: 38 additions & 6 deletions core/shared/src/main/scala/scalaz/zio/internal/FiberContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,16 @@ package scalaz.zio.internal

import java.util.concurrent.atomic.{ AtomicLong, AtomicReference }

import scalaz.zio._
import scalaz.zio.internal.FiberContext.FiberRefLocals
import scalaz.zio.{ UIO, _ }

import scala.annotation.{ switch, tailrec }

/**
* An implementation of Fiber that maintains context necessary for evaluation.
*/
private[zio] final class FiberContext[E, A](
platform: Platform,
startEnv: AnyRef
) extends Fiber[E, A] {
private[zio] final class FiberContext[E, A](platform: Platform, startEnv: AnyRef, fiberRefLocals: FiberRefLocals)
extends Fiber[E, A] {
import java.util.{ Collections, Set }

import FiberContext._
Expand Down Expand Up @@ -293,6 +292,24 @@ private[zio] final class FiberContext[E, A](
val io = curIo.asInstanceOf[ZIO.SuspendWith[Any, E, Any]]

curIo = io.f(platform)

case ZIO.Tags.FiberRefNew =>
val io = curIo.asInstanceOf[ZIO.FiberRefNew[Any]]

val fiberRef = new FiberRef[Any](io.initialValue)
fiberRefLocals.put(fiberRef, io.initialValue)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

synchronized is not supported in Scala.js, which we have to support.

We could push that requirement into WeakHashMap on Platform, which could return a synchronized weak hash map.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed docs for scalaz.zio.internal.Platform#newWeakHashMap, modified implementation in scalaz.zio.internal.PlatformLive

curIo = nextInstr(fiberRef)

case ZIO.Tags.FiberRefModify =>
val io = curIo.asInstanceOf[ZIO.FiberRefModify[Any, Any]]

val oldValue = Option(fiberRefLocals.get(io.fiberRef))
val (result, newValue) = io.f(oldValue.getOrElse(io.fiberRef.initial))
fiberRefLocals.put(io.fiberRef, newValue)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for above.

curIo = nextInstr(result)

}
}
} else {
Expand Down Expand Up @@ -346,7 +363,9 @@ private[zio] final class FiberContext[E, A](
* Forks an `IO` with the specified failure handler.
*/
final def fork[E, A](io: IO[E, A]): FiberContext[E, A] = {
val context = new FiberContext[E, A](platform, environment.peek())
val childFiberRefLocals: FiberRefLocals = platform.newWeakHashMap()
childFiberRefLocals.putAll(fiberRefLocals)
hanny24 marked this conversation as resolved.
Show resolved Hide resolved
val context = new FiberContext[E, A](platform, environment.peek(), childFiberRefLocals)

platform.executor.submitOrThrow(() => context.evaluateNow(io))

Expand Down Expand Up @@ -374,6 +393,17 @@ private[zio] final class FiberContext[E, A](

final def poll: UIO[Option[Exit[E, A]]] = ZIO.effectTotal(poll0)

final def inheritFiberRefs: UIO[Unit] = UIO.suspend {
import scala.collection.JavaConverters._
val locals = fiberRefLocals.asScala
if (locals.isEmpty) UIO.unit
else
UIO.foreach_(locals) {
case (fiberRef, value) =>
fiberRef.asInstanceOf[FiberRef[Any]].set(value)
}
}

private[this] final def enterSupervision: IO[E, Unit] = ZIO.effectTotal {
supervising += 1

Expand Down Expand Up @@ -569,4 +599,6 @@ private[zio] object FiberContext {

def Initial[E, A] = Executing[E, A](FiberStatus.Running, Nil)
}

type FiberRefLocals = java.util.Map[FiberRef[_], Any]
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ trait Platform { self =>
}

/**
* Creates a new java.util.WeakHashMap if supported by the platform,
* Creates a new thread safe java.util.WeakHashMap if supported by the platform,
* otherwise any implementation of Map.
*/
def newWeakHashMap[A, B](): JMap[A, B]
Expand Down