Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize fiber id and executor access [series/2.0.x] #8716

Merged
merged 2 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ final class ReentrantLock private (fairness: Boolean, state: Ref[ReentrantLock.S

/** Queries the number of holds on this lock by the current fiber. */
lazy val holdCount: UIO[Int] =
ZIO.fiberId.flatMap { fiberId =>
ZIO.fiberIdWith { fiberId =>
state.get.map {
case State(_, Some(`fiberId`), cnt, _) => cnt
case _ => 0
Expand All @@ -28,7 +28,7 @@ final class ReentrantLock private (fairness: Boolean, state: Ref[ReentrantLock.S

/** Queries if this lock is held by the current fiber. */
lazy val isHeldByCurrentFiber: UIO[Boolean] =
ZIO.fiberId.flatMap { fiberId =>
ZIO.fiberIdWith { fiberId =>
state.get.map {
case State(_, Some(`fiberId`), _, _) => true
case _ => false
Expand Down Expand Up @@ -91,7 +91,7 @@ final class ReentrantLock private (fairness: Boolean, state: Ref[ReentrantLock.S
* invocation.
*/
lazy val tryLock: UIO[Boolean] =
ZIO.fiberId.flatMap { fiberId =>
ZIO.fiberIdWith { fiberId =>
state.modify {
case State(ep, Some(`fiberId`), cnt, holders) =>
true -> State(ep + 1, Some(fiberId), cnt + 1, holders)
Expand All @@ -110,7 +110,7 @@ final class ReentrantLock private (fairness: Boolean, state: Ref[ReentrantLock.S
* the current thread is not the holder of this lock then nothing happens.
*/
lazy val unlock: UIO[Unit] =
ZIO.fiberId.flatMap { fiberId =>
ZIO.fiberIdWith { fiberId =>
state.modify {
case State(ep, Some(`fiberId`), 1, holders) =>
relock(ep, holders)
Expand Down
4 changes: 2 additions & 2 deletions core/shared/src/main/scala/zio/Fiber.scala
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ abstract class Fiber[+E, +A] { self =>
* `UIO[Exit, E, A]]`
*/
final def interrupt(implicit trace: Trace): UIO[Exit[E, A]] =
ZIO.fiberId.flatMap(fiberId => self.interruptAs(fiberId))
ZIO.fiberIdWith(fiberId => self.interruptAs(fiberId))

/**
* Interrupts the fiber as if interrupted from the specified fiber. If the
Expand Down Expand Up @@ -914,7 +914,7 @@ object Fiber extends FiberPlatformSpecific {
* `UIO[Unit]`
*/
def interruptAll(fs: Iterable[Fiber[Any, Any]])(implicit trace: Trace): UIO[Unit] =
ZIO.fiberId.flatMap(interruptAllAs(_)(fs))
ZIO.fiberIdWith(interruptAllAs(_)(fs))

/**
* Interrupts all fibers as by the specified fiber, awaiting their
Expand Down
2 changes: 1 addition & 1 deletion core/shared/src/main/scala/zio/Promise.scala
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ final class Promise[E, A] private (
* waiting on the value of the promise as by the fiber calling this method.
*/
def interrupt(implicit trace: Trace): UIO[Boolean] =
ZIO.fiberId.flatMap(id => completeWith(ZIO.interruptAs(id)))
ZIO.fiberIdWith(id => completeWith(ZIO.interruptAs(id)))

/**
* Completes the promise with interruption. This will interrupt all fibers
Expand Down
2 changes: 1 addition & 1 deletion core/shared/src/main/scala/zio/Runtime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ trait Runtime[+R] { self =>
* Runs the effect "purely" through an async boundary. Useful for testing.
*/
final def run[E, A](zio: ZIO[R, E, A])(implicit trace: Trace): IO[E, A] =
ZIO.fiberId.flatMap { fiberId =>
ZIO.fiberIdWith { fiberId =>
ZIO.asyncInterrupt[Any, E, A] { callback =>
val fiber = unsafe.fork(zio)(trace, Unsafe.unsafe)
fiber.unsafe.addObserver(exit => callback(ZIO.done(exit)))(Unsafe.unsafe)
Expand Down
16 changes: 8 additions & 8 deletions core/shared/src/main/scala/zio/ZIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3118,13 +3118,13 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific
* Retrieves the executor for this effect.
*/
def executor(implicit trace: Trace): UIO[Executor] =
ZIO.descriptorWith(descriptor => ZIO.succeed(descriptor.executor))
ZIO.executorWith(ZIO.succeed(_))

/**
* Constructs an effect based on the current executor.
*/
def executorWith[R, E, A](f: Executor => ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] =
ZIO.descriptorWith(descriptor => f(descriptor.executor))
ZIO.withFiberRuntime[R, E, A]((fiberState, _) => f(fiberState.getCurrentExecutor()(Unsafe.unsafe)))

/**
* Determines whether any element of the `Iterable[A]` satisfies the effectual
Expand Down Expand Up @@ -3171,7 +3171,7 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific
* effect that calls this method.
*/
def fiberIdWith[R, E, A](f: FiberId.Runtime => ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] =
ZIO.descriptorWith(descriptor => f(descriptor.id))
ZIO.withFiberRuntime[R, E, A]((fiberState, _) => f(fiberState.id))

/**
* Filters the collection using the specified effectual predicate.
Expand Down Expand Up @@ -3584,9 +3584,9 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific
* [[scala.concurrent.ExecutionContext]] that is backed by ZIO's own executor.
*/
def fromFuture[A](make: ExecutionContext => scala.concurrent.Future[A])(implicit trace: Trace): Task[A] =
ZIO.descriptorWith { d =>
ZIO.executorWith { executor =>
import scala.util.{Failure, Success}
val ec = d.executor.asExecutionContext
val ec = executor.asExecutionContext
ZIO.attempt(make(ec)).flatMap { f =>
val canceler: UIO[Unit] = f match {
case cancelable: CancelableFuture[A] =>
Expand Down Expand Up @@ -3627,9 +3627,9 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific
def fromFutureInterrupt[A](
make: ExecutionContext => scala.concurrent.Future[A]
)(implicit trace: Trace): Task[A] =
ZIO.descriptorWith { d =>
ZIO.executorWith { executor =>
import scala.util.{Failure, Success}
val ec = d.executor.asExecutionContext
val ec = executor.asExecutionContext
val interrupted = new java.util.concurrent.atomic.AtomicBoolean(false)
val latch = scala.concurrent.Promise[Unit]()
val interruptibleEC = new scala.concurrent.ExecutionContext {
Expand Down Expand Up @@ -3743,7 +3743,7 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific
* method.
*/
def interrupt(implicit trace: Trace): UIO[Nothing] =
descriptorWith(descriptor => interruptAs(descriptor.id))
fiberIdWith(interruptAs(_))

/**
* Returns an effect that is interrupted as if by the specified fiber.
Expand Down
6 changes: 3 additions & 3 deletions managed/shared/src/main/scala/zio/managed/ZManaged.scala
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ sealed abstract class ZManaged[-R, +E, +A] extends ZManagedVersionSpecific[R, E,
a <- raceResult match {
case Right(value) => ZIO.succeed(Some(value))
case Left(fiber) =>
ZIO.fiberId.flatMap { id =>
ZIO.fiberIdWith { id =>
fiber.interrupt
.ensuring(innerReleaseMap.releaseAll(Exit.interrupt(id), ExecutionStrategy.Sequential))
.forkDaemon
Expand Down Expand Up @@ -2136,7 +2136,7 @@ object ZManaged extends ZManagedPlatformSpecific {
* method.
*/
def interrupt(implicit trace: Trace): ZManaged[Any, Nothing, Nothing] =
ZManaged.fromZIO(ZIO.descriptor).flatMap(d => failCause(Cause.interrupt(d.id)))
ZManaged.fromZIO(ZIO.fiberId).flatMap(id => failCause(Cause.interrupt(id)))

/**
* Returns an effect that is interrupted as if by the specified fiber.
Expand Down Expand Up @@ -2707,7 +2707,7 @@ object ZManaged extends ZManagedPlatformSpecific {
ZManaged.unwrap(Supervisor.track(true).map { supervisor =>
// Filter out the fiber id of whoever is calling this:
ZManaged(
get(supervisor.value.flatMap(children => ZIO.descriptor.map(d => children.filter(_.id != d.id)))).zio
get(supervisor.value.flatMap(children => ZIO.fiberId.map(id => children.filter(_.id != id)))).zio
.supervised(supervisor)
)
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ private[zio] class SingleProducerAsyncInput[Err, Elem, Done](
takeWith(c => Exit.failCause(c.map(Left(_))), Exit.succeed(_), d => Exit.fail(Right(d)))

def close(implicit trace: Trace): UIO[Any] =
ZIO.fiberId.flatMap(id => error(Cause.interrupt(id)))
ZIO.fiberIdWith(id => error(Cause.interrupt(id)))

def awaitRead(implicit trace: Trace): UIO[Any] =
ref.modify {
Expand Down
4 changes: 2 additions & 2 deletions test/shared/src/main/scala/zio/test/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ object Annotations {
): ZIO[R, TestFailure[E], TestSuccess] =
zio.foldZIO(e => ref.get.map(e.annotated).flip, a => ref.get.map(a.annotated))
def supervisedFibers(implicit trace: Trace): UIO[SortedSet[Fiber.Runtime[Any, Any]]] =
ZIO.descriptorWith { descriptor =>
ZIO.fiberIdWith { fiberId =>
get(TestAnnotation.fibers).flatMap {
case Left(_) => ZIO.succeed(SortedSet.empty[Fiber.Runtime[Any, Any]])
case Right(refs) =>
ZIO
.foreach(refs)(ref => ZIO.succeed(ref.get))
.map(_.foldLeft(SortedSet.empty[Fiber.Runtime[Any, Any]])(_ ++ _))
.map(_.filter(_.id != descriptor.id))
.map(_.filter(_.id != fiberId))
}
}
private[zio] def unsafe: UnsafeAPI =
Expand Down
4 changes: 2 additions & 2 deletions test/shared/src/main/scala/zio/test/TestClock.scala
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,14 @@ object TestClock extends Serializable {
* Returns a set of all fibers in this test.
*/
def supervisedFibers(implicit trace: Trace): UIO[SortedSet[Fiber.Runtime[Any, Any]]] =
ZIO.descriptorWith { descriptor =>
ZIO.fiberIdWith { fiberId =>
annotations.get(TestAnnotation.fibers).flatMap {
case Left(_) => ZIO.succeed(SortedSet.empty[Fiber.Runtime[Any, Any]])
case Right(refs) =>
ZIO
.foreach(refs)(ref => ZIO.succeed(ref.get))
.map(_.foldLeft(SortedSet.empty[Fiber.Runtime[Any, Any]])(_ ++ _))
.map(_.filter(_.id != descriptor.id))
.map(_.filter(_.id != fiberId))
}
}

Expand Down