From cfb3ed6ee407126018527d7a7595d29b9312cdb7 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Tue, 9 Apr 2024 14:10:27 +0900 Subject: [PATCH 1/2] Optimize fiber id and executor access --- .../scala/zio/concurrent/ReentrantLock.scala | 8 ++++---- core/shared/src/main/scala/zio/Fiber.scala | 4 ++-- core/shared/src/main/scala/zio/Promise.scala | 2 +- core/shared/src/main/scala/zio/Runtime.scala | 2 +- core/shared/src/main/scala/zio/ZIO.scala | 16 ++++++++-------- .../src/main/scala/zio/managed/ZManaged.scala | 6 +++--- .../zio/stream/internal/ChannelExecutor.scala | 2 +- .../src/main/scala/zio/test/Annotations.scala | 4 ++-- .../src/main/scala/zio/test/TestClock.scala | 4 ++-- 9 files changed, 24 insertions(+), 24 deletions(-) diff --git a/concurrent/shared/src/main/scala/zio/concurrent/ReentrantLock.scala b/concurrent/shared/src/main/scala/zio/concurrent/ReentrantLock.scala index 0f76b74d790..ec929b62087 100644 --- a/concurrent/shared/src/main/scala/zio/concurrent/ReentrantLock.scala +++ b/concurrent/shared/src/main/scala/zio/concurrent/ReentrantLock.scala @@ -16,7 +16,7 @@ final class ReentrantLock private (fairness: Boolean, state: Ref[ReentrantLock.S /** Queries the number of holds on this lock by the current fiber. */ lazy val holdCount: UIO[Int] = - ZIO.fiberId.flatMap { fiberId => + ZIO.fiberIdWith { fiberId => state.get.map { case State(_, Some(`fiberId`), cnt, _) => cnt case _ => 0 @@ -28,7 +28,7 @@ final class ReentrantLock private (fairness: Boolean, state: Ref[ReentrantLock.S /** Queries if this lock is held by the current fiber. */ lazy val isHeldByCurrentFiber: UIO[Boolean] = - ZIO.fiberId.flatMap { fiberId => + ZIO.fiberIdWith { fiberId => state.get.map { case State(_, Some(`fiberId`), _, _) => true case _ => false @@ -91,7 +91,7 @@ final class ReentrantLock private (fairness: Boolean, state: Ref[ReentrantLock.S * invocation. */ lazy val tryLock: UIO[Boolean] = - ZIO.fiberId.flatMap { fiberId => + ZIO.fiberIdWith { fiberId => state.modify { case State(ep, Some(`fiberId`), cnt, holders) => true -> State(ep + 1, Some(fiberId), cnt + 1, holders) @@ -110,7 +110,7 @@ final class ReentrantLock private (fairness: Boolean, state: Ref[ReentrantLock.S * the current thread is not the holder of this lock then nothing happens. */ lazy val unlock: UIO[Unit] = - ZIO.fiberId.flatMap { fiberId => + ZIO.fiberIdWith { fiberId => state.modify { case State(ep, Some(`fiberId`), 1, holders) => relock(ep, holders) diff --git a/core/shared/src/main/scala/zio/Fiber.scala b/core/shared/src/main/scala/zio/Fiber.scala index f857301a215..a58db0285cd 100644 --- a/core/shared/src/main/scala/zio/Fiber.scala +++ b/core/shared/src/main/scala/zio/Fiber.scala @@ -166,7 +166,7 @@ abstract class Fiber[+E, +A] { self => * `UIO[Exit, E, A]]` */ final def interrupt(implicit trace: Trace): UIO[Exit[E, A]] = - ZIO.fiberId.flatMap(fiberId => self.interruptAs(fiberId)) + ZIO.fiberIdWith(fiberId => self.interruptAs(fiberId)) /** * Interrupts the fiber as if interrupted from the specified fiber. If the @@ -914,7 +914,7 @@ object Fiber extends FiberPlatformSpecific { * `UIO[Unit]` */ def interruptAll(fs: Iterable[Fiber[Any, Any]])(implicit trace: Trace): UIO[Unit] = - ZIO.fiberId.flatMap(interruptAllAs(_)(fs)) + ZIO.fiberIdWith(interruptAllAs(_)(fs)) /** * Interrupts all fibers as by the specified fiber, awaiting their diff --git a/core/shared/src/main/scala/zio/Promise.scala b/core/shared/src/main/scala/zio/Promise.scala index f8693769a69..6ab17c77054 100644 --- a/core/shared/src/main/scala/zio/Promise.scala +++ b/core/shared/src/main/scala/zio/Promise.scala @@ -159,7 +159,7 @@ final class Promise[E, A] private ( * waiting on the value of the promise as by the fiber calling this method. */ def interrupt(implicit trace: Trace): UIO[Boolean] = - ZIO.fiberId.flatMap(id => completeWith(ZIO.interruptAs(id))) + ZIO.fiberIdWith(id => completeWith(ZIO.interruptAs(id))) /** * Completes the promise with interruption. This will interrupt all fibers diff --git a/core/shared/src/main/scala/zio/Runtime.scala b/core/shared/src/main/scala/zio/Runtime.scala index cc6569d1f5f..51935d6df35 100644 --- a/core/shared/src/main/scala/zio/Runtime.scala +++ b/core/shared/src/main/scala/zio/Runtime.scala @@ -48,7 +48,7 @@ trait Runtime[+R] { self => * Runs the effect "purely" through an async boundary. Useful for testing. */ final def run[E, A](zio: ZIO[R, E, A])(implicit trace: Trace): IO[E, A] = - ZIO.fiberId.flatMap { fiberId => + ZIO.fiberIdWith { fiberId => ZIO.asyncInterrupt[Any, E, A] { callback => val fiber = unsafe.fork(zio)(trace, Unsafe.unsafe) fiber.unsafe.addObserver(exit => callback(ZIO.done(exit)))(Unsafe.unsafe) diff --git a/core/shared/src/main/scala/zio/ZIO.scala b/core/shared/src/main/scala/zio/ZIO.scala index 7e320c01913..ab7a0835545 100644 --- a/core/shared/src/main/scala/zio/ZIO.scala +++ b/core/shared/src/main/scala/zio/ZIO.scala @@ -3118,13 +3118,13 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific * Retrieves the executor for this effect. */ def executor(implicit trace: Trace): UIO[Executor] = - ZIO.descriptorWith(descriptor => ZIO.succeed(descriptor.executor)) + ZIO.executorWith(ZIO.succeed(_)) /** * Constructs an effect based on the current executor. */ def executorWith[R, E, A](f: Executor => ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] = - ZIO.descriptorWith(descriptor => f(descriptor.executor)) + ZIO.withFiberRuntime[R, E, A]((fiberState, _) => f(fiberState.getCurrentExecutor()(Unsafe.unsafe))) /** * Determines whether any element of the `Iterable[A]` satisfies the effectual @@ -3171,7 +3171,7 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific * effect that calls this method. */ def fiberIdWith[R, E, A](f: FiberId.Runtime => ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] = - ZIO.descriptorWith(descriptor => f(descriptor.id)) + withFiberRuntime[R, E, A]((fiberState, _) => f(fiberState.id)) /** * Filters the collection using the specified effectual predicate. @@ -3584,9 +3584,9 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific * [[scala.concurrent.ExecutionContext]] that is backed by ZIO's own executor. */ def fromFuture[A](make: ExecutionContext => scala.concurrent.Future[A])(implicit trace: Trace): Task[A] = - ZIO.descriptorWith { d => + ZIO.executorWith { executor => import scala.util.{Failure, Success} - val ec = d.executor.asExecutionContext + val ec = executor.asExecutionContext ZIO.attempt(make(ec)).flatMap { f => val canceler: UIO[Unit] = f match { case cancelable: CancelableFuture[A] => @@ -3627,9 +3627,9 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific def fromFutureInterrupt[A]( make: ExecutionContext => scala.concurrent.Future[A] )(implicit trace: Trace): Task[A] = - ZIO.descriptorWith { d => + ZIO.executorWith { executor => import scala.util.{Failure, Success} - val ec = d.executor.asExecutionContext + val ec = executor.asExecutionContext val interrupted = new java.util.concurrent.atomic.AtomicBoolean(false) val latch = scala.concurrent.Promise[Unit]() val interruptibleEC = new scala.concurrent.ExecutionContext { @@ -3743,7 +3743,7 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific * method. */ def interrupt(implicit trace: Trace): UIO[Nothing] = - descriptorWith(descriptor => interruptAs(descriptor.id)) + fiberIdWith(interruptAs(_)) /** * Returns an effect that is interrupted as if by the specified fiber. diff --git a/managed/shared/src/main/scala/zio/managed/ZManaged.scala b/managed/shared/src/main/scala/zio/managed/ZManaged.scala index 906037943fb..5a7e65833bc 100644 --- a/managed/shared/src/main/scala/zio/managed/ZManaged.scala +++ b/managed/shared/src/main/scala/zio/managed/ZManaged.scala @@ -967,7 +967,7 @@ sealed abstract class ZManaged[-R, +E, +A] extends ZManagedVersionSpecific[R, E, a <- raceResult match { case Right(value) => ZIO.succeed(Some(value)) case Left(fiber) => - ZIO.fiberId.flatMap { id => + ZIO.fiberIdWith { id => fiber.interrupt .ensuring(innerReleaseMap.releaseAll(Exit.interrupt(id), ExecutionStrategy.Sequential)) .forkDaemon @@ -2136,7 +2136,7 @@ object ZManaged extends ZManagedPlatformSpecific { * method. */ def interrupt(implicit trace: Trace): ZManaged[Any, Nothing, Nothing] = - ZManaged.fromZIO(ZIO.descriptor).flatMap(d => failCause(Cause.interrupt(d.id))) + ZManaged.fromZIO(ZIO.fiberId).flatMap(id => failCause(Cause.interrupt(id))) /** * Returns an effect that is interrupted as if by the specified fiber. @@ -2707,7 +2707,7 @@ object ZManaged extends ZManagedPlatformSpecific { ZManaged.unwrap(Supervisor.track(true).map { supervisor => // Filter out the fiber id of whoever is calling this: ZManaged( - get(supervisor.value.flatMap(children => ZIO.descriptor.map(d => children.filter(_.id != d.id)))).zio + get(supervisor.value.flatMap(children => ZIO.fiberId.map(id => children.filter(_.id != id)))).zio .supervised(supervisor) ) }) diff --git a/streams/shared/src/main/scala/zio/stream/internal/ChannelExecutor.scala b/streams/shared/src/main/scala/zio/stream/internal/ChannelExecutor.scala index 7d3b1bbfafa..f49b9ad8598 100644 --- a/streams/shared/src/main/scala/zio/stream/internal/ChannelExecutor.scala +++ b/streams/shared/src/main/scala/zio/stream/internal/ChannelExecutor.scala @@ -841,7 +841,7 @@ private[zio] class SingleProducerAsyncInput[Err, Elem, Done]( takeWith(c => Exit.failCause(c.map(Left(_))), Exit.succeed(_), d => Exit.fail(Right(d))) def close(implicit trace: Trace): UIO[Any] = - ZIO.fiberId.flatMap(id => error(Cause.interrupt(id))) + ZIO.fiberIdWith(id => error(Cause.interrupt(id))) def awaitRead(implicit trace: Trace): UIO[Any] = ref.modify { diff --git a/test/shared/src/main/scala/zio/test/Annotations.scala b/test/shared/src/main/scala/zio/test/Annotations.scala index 6e830cb558e..29e5566c8a4 100644 --- a/test/shared/src/main/scala/zio/test/Annotations.scala +++ b/test/shared/src/main/scala/zio/test/Annotations.scala @@ -41,14 +41,14 @@ object Annotations { ): ZIO[R, TestFailure[E], TestSuccess] = zio.foldZIO(e => ref.get.map(e.annotated).flip, a => ref.get.map(a.annotated)) def supervisedFibers(implicit trace: Trace): UIO[SortedSet[Fiber.Runtime[Any, Any]]] = - ZIO.descriptorWith { descriptor => + ZIO.fiberIdWith { fiberId => get(TestAnnotation.fibers).flatMap { case Left(_) => ZIO.succeed(SortedSet.empty[Fiber.Runtime[Any, Any]]) case Right(refs) => ZIO .foreach(refs)(ref => ZIO.succeed(ref.get)) .map(_.foldLeft(SortedSet.empty[Fiber.Runtime[Any, Any]])(_ ++ _)) - .map(_.filter(_.id != descriptor.id)) + .map(_.filter(_.id != fiberId)) } } private[zio] def unsafe: UnsafeAPI = diff --git a/test/shared/src/main/scala/zio/test/TestClock.scala b/test/shared/src/main/scala/zio/test/TestClock.scala index 899fc7e4ab6..3b7ec133b11 100644 --- a/test/shared/src/main/scala/zio/test/TestClock.scala +++ b/test/shared/src/main/scala/zio/test/TestClock.scala @@ -308,14 +308,14 @@ object TestClock extends Serializable { * Returns a set of all fibers in this test. */ def supervisedFibers(implicit trace: Trace): UIO[SortedSet[Fiber.Runtime[Any, Any]]] = - ZIO.descriptorWith { descriptor => + ZIO.fiberIdWith { fiberId => annotations.get(TestAnnotation.fibers).flatMap { case Left(_) => ZIO.succeed(SortedSet.empty[Fiber.Runtime[Any, Any]]) case Right(refs) => ZIO .foreach(refs)(ref => ZIO.succeed(ref.get)) .map(_.foldLeft(SortedSet.empty[Fiber.Runtime[Any, Any]])(_ ++ _)) - .map(_.filter(_.id != descriptor.id)) + .map(_.filter(_.id != fiberId)) } } From 7902d0626a3e595e83696be6dc51dea25d6d06c4 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Tue, 9 Apr 2024 14:35:35 +0900 Subject: [PATCH 2/2] Polish --- core/shared/src/main/scala/zio/ZIO.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/shared/src/main/scala/zio/ZIO.scala b/core/shared/src/main/scala/zio/ZIO.scala index ab7a0835545..6db76999979 100644 --- a/core/shared/src/main/scala/zio/ZIO.scala +++ b/core/shared/src/main/scala/zio/ZIO.scala @@ -3171,7 +3171,7 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific * effect that calls this method. */ def fiberIdWith[R, E, A](f: FiberId.Runtime => ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] = - withFiberRuntime[R, E, A]((fiberState, _) => f(fiberState.id)) + ZIO.withFiberRuntime[R, E, A]((fiberState, _) => f(fiberState.id)) /** * Filters the collection using the specified effectual predicate.