diff --git a/core/shared/src/main/scala/zio/Fiber.scala b/core/shared/src/main/scala/zio/Fiber.scala index 17875b9a301..d93b700dcf5 100644 --- a/core/shared/src/main/scala/zio/Fiber.scala +++ b/core/shared/src/main/scala/zio/Fiber.scala @@ -530,6 +530,7 @@ object Fiber extends FiberPlatformSpecific { * '''NOTE''': This method must be invoked by the fiber itself. */ private[zio] def addChild(child: Fiber.Runtime[_, _]): Unit + private[zio] def addChildren(children: Iterable[Fiber.Runtime[_, _]]): Unit /** * Deletes the specified fiber ref. @@ -619,6 +620,7 @@ object Fiber extends FiberPlatformSpecific { * Adds a message to add a child to this fiber. */ private[zio] def tellAddChild(child: Fiber.Runtime[_, _]): Unit + private[zio] def tellAddChildren(children: Iterable[Fiber.Runtime[_, _]]): Unit /** * Adds a message to interrupt this fiber. diff --git a/core/shared/src/main/scala/zio/internal/FiberRuntime.scala b/core/shared/src/main/scala/zio/internal/FiberRuntime.scala index caac03903e2..edb7fb6fc01 100644 --- a/core/shared/src/main/scala/zio/internal/FiberRuntime.scala +++ b/core/shared/src/main/scala/zio/internal/FiberRuntime.scala @@ -87,15 +87,22 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, ) } - def children(implicit trace: Trace): UIO[Chunk[Fiber.Runtime[_, _]]] = - ZIO.succeed { - val childs = _children - if (childs == null) Chunk.empty - else - zio.internal.Sync(childs) { - Chunk.fromJavaIterable(childs) - } + private def childrenChunk = { + //may be executed by a foreign fiber (under Sync), hence we're risking a race over the _children variable being set back to null by a concurrent transferChildren call + val childs = _children + if (childs eq null) Chunk.empty + else { + val bldr = Chunk.newBuilder[Fiber.Runtime[_, _]] + childs.forEach { child => + if ((child ne null) && child.isAlive()) + bldr.addOne(child) + } + bldr.result() } + } + + def children(implicit trace: Trace): UIO[Chunk[Fiber.Runtime[_, _]]] = + ZIO.succeed(self.childrenChunk) def fiberRefs(implicit trace: Trace): UIO[FiberRefs] = ZIO.succeed(_fiberRefs) @@ -151,14 +158,47 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, } private[zio] def addChild(child: Fiber.Runtime[_, _]): Unit = - if (isAlive()) { - getChildren().add(child) + if (child.isAlive()) { + if (isAlive()) { + getChildren().add(child) - if (isInterrupted()) + if (isInterrupted()) + child.tellInterrupt(getInterruptedCause()) + } else { child.tellInterrupt(getInterruptedCause()) + } + } + + private[zio] def addChildren(children: Iterable[Fiber.Runtime[_, _]]): Unit = { + val iter = children.iterator + if (isAlive()) { + val childs = getChildren() + + if (isInterrupted()) { + val cause = getInterruptedCause() + while (iter.hasNext) { + val child = iter.next() + if (child.isAlive()) { + childs.add(child) + child.tellInterrupt(cause) + } + } + } else { + while (iter.hasNext) { + val child = iter.next() + if (child.isAlive()) + childs.add(child) + } + } } else { - child.tellInterrupt(getInterruptedCause()) + val cause = getInterruptedCause() + while (iter.hasNext) { + val child = iter.next() + if (child.isAlive()) + child.tellInterrupt(cause) + } } + } /** * Adds an interruptor to the set of interruptors that are interrupting this @@ -483,7 +523,8 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, * * '''NOTE''': This method must be invoked by the fiber itself. */ - private[zio] def getChildren(): JavaSet[Fiber.Runtime[_, _]] = { + private def getChildren(): JavaSet[Fiber.Runtime[_, _]] = { + //executed by the fiber itself, no risk of racing with transferChildren if (_children eq null) { _children = Platform.newConcurrentWeakSet[Fiber.Runtime[_, _]]()(Unsafe.unsafe) } @@ -652,18 +693,29 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, private def interruptAllChildren(): UIO[Any] = if (sendInterruptSignalToAllChildren(_children)) { val iterator = _children.iterator() - _children = null - val body = () => { - val next = iterator.next() + var curr: Fiber.Runtime[_, _] = null - if (next != null) next.await(id.location) else Exit.unit + //this finds the next operable child fiber and stores it in the `curr` variable + def skip() = { + var next: Fiber.Runtime[_, _] = null + while (iterator.hasNext && (next eq null)) { + next = iterator.next() + if ((next ne null) && !next.isAlive()) + next = null + } + curr = next } - // Now await all children to finish: - ZIO - .whileLoop(iterator.hasNext)(body())(_ => ())(id.location) + //find the first operable child fiber + //if there isn't any we can simply return null and save ourselves an effect evaluation + skip() + + if (null ne curr) { + ZIO + .whileLoop(null ne curr)(curr.await(id.location))(_ => skip())(id.location) + } else null } else null private[zio] def isAlive(): Boolean = @@ -843,7 +895,7 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, * * '''NOTE''': This method must be invoked by the fiber itself. */ - private[zio] def removeChild(child: FiberRuntime[_, _]): Unit = + private def removeChild(child: FiberRuntime[_, _]): Unit = if (_children ne null) { _children.remove(child) () @@ -1170,16 +1222,6 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, done } - private def sendInterruptSignalToAllChildrenConcurrently(): Boolean = { - val childFibers = _children - - if (childFibers ne null) { - internal.Sync(childFibers) { - sendInterruptSignalToAllChildren(childFibers) - } - } else false - } - private def sendInterruptSignalToAllChildren( children: JavaSet[Fiber.Runtime[_, _]] ): Boolean = @@ -1375,6 +1417,9 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, private[zio] def tellAddChild(child: Fiber.Runtime[_, _]): Unit = tell(FiberMessage.Stateful((parentFiber, _) => parentFiber.addChild(child))) + private[zio] def tellAddChildren(children: Iterable[Fiber.Runtime[_, _]]): Unit = + tell(FiberMessage.Stateful((parentFiber, _) => parentFiber.addChildren(children))) + private[zio] def tellInterrupt(cause: Cause[Nothing]): Unit = tell(FiberMessage.InterruptSignal(cause)) @@ -1385,24 +1430,16 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, * '''NOTE''': This method must be invoked by the fiber itself after it has * evaluated the effects but prior to exiting */ - private[zio] def transferChildren(scope: FiberScope): Unit = { - val children = _children - if ((children ne null) && !children.isEmpty) { - val iterator = children.iterator() - val flags = _runtimeFlags - - while (iterator.hasNext) { - val next = iterator.next() - - // Only move alive children. - // Unless we forked fibers and didn't await them, we shouldn't have any alive children in the set. - if ((next ne null) && next.isAlive()) { - scope.add(self, flags, next)(location, Unsafe.unsafe) - iterator.remove() - } - } + private[zio] def transferChildren(scope: FiberScope): Unit = + if ((_children ne null) && !_children.isEmpty) { + val childs = childrenChunk + //we're effectively clearing this set, seems cheaper to 'drop' it and allocate a new one if we spawn more fibers + //a concurrent children call might get the stale set, but this method (and its primary usage for dumping fibers) + //is racy by definition + _children = null + val flags = _runtimeFlags + scope.addAll(self, flags, childs)(location, Unsafe.unsafe) } - } /** * Updates a fiber ref belonging to this fiber by using the provided update diff --git a/core/shared/src/main/scala/zio/internal/FiberScope.scala b/core/shared/src/main/scala/zio/internal/FiberScope.scala index 3dddfb07e78..9a35c3ff56f 100644 --- a/core/shared/src/main/scala/zio/internal/FiberScope.scala +++ b/core/shared/src/main/scala/zio/internal/FiberScope.scala @@ -36,6 +36,15 @@ private[zio] sealed trait FiberScope { trace: Trace, unsafe: Unsafe ): Unit + + private[zio] def addAll( + currentFiber: Fiber.Runtime[_, _], + runtimeFlags: RuntimeFlags, + children: Iterable[Fiber.Runtime[_, _]] + )(implicit + trace: Trace, + unsafe: Unsafe + ): Unit } private[zio] object FiberScope { @@ -56,6 +65,20 @@ private[zio] object FiberScope { if (RuntimeFlags.fiberRoots(runtimeFlags)) { Fiber._roots.add(child) } + + private[zio] def addAll( + currentFiber: Fiber.Runtime[_, _], + runtimeFlags: RuntimeFlags, + children: Iterable[Fiber.Runtime[_, _]] + )(implicit + trace: Trace, + unsafe: Unsafe + ): Unit = + if (RuntimeFlags.fiberRoots(runtimeFlags)) { + children.foreach { + Fiber._roots.add(_) + } + } } private final class Local(val fiberId: FiberId, parentRef: WeakReference[Fiber.Runtime[_, _]]) extends FiberScope { @@ -85,6 +108,37 @@ private[zio] object FiberScope { child.tellInterrupt(Cause.interrupt(currentFiber.id)) } } + + private[zio] def addAll( + currentFiber: Fiber.Runtime[_, _], + runtimeFlags: RuntimeFlags, + children: Iterable[Fiber.Runtime[_, _]] + )(implicit + trace: Trace, + unsafe: Unsafe + ): Unit = if (children.nonEmpty) { + val parent = parentRef.get() + + if (parent ne null) { + // Parent is not GC'd. Let's check to see if the parent is the current + // fiber: + if (currentFiber eq parent) { + // The parent is the current fiber so it is safe to directly add the + // child to the parent: + parent.addChildren(children) + } else { + // The parent is not the current fiber. So we need to send a message + // to the parent so it will add the child to itself: + parent.tellAddChildren(children) + } + } else { + // Parent was GC'd. We immediately interrupt the child fiber using the id + // of the current fiber (which is adding the child to the parent): + children.foreach( + _.tellInterrupt(Cause.interrupt(currentFiber.id)) + ) + } + } } private[zio] def make(fiber: FiberRuntime[_, _]): FiberScope = diff --git a/streams-tests/shared/src/test/scala/zio/stream/ZStreamSpec.scala b/streams-tests/shared/src/test/scala/zio/stream/ZStreamSpec.scala index 7cd9642ee3e..b5713b7d403 100644 --- a/streams-tests/shared/src/test/scala/zio/stream/ZStreamSpec.scala +++ b/streams-tests/shared/src/test/scala/zio/stream/ZStreamSpec.scala @@ -2677,7 +2677,58 @@ object ZStreamSpec extends ZIOBaseSpec { val stream = ZStream.fromIterable(0 to 3).mapZIOParUnordered(10)(_ => ZIO.fail("fail")) assertZIO(stream.runDrain.exit)(fails(equalTo("fail"))) - } @@ nonFlaky @@ TestAspect.diagnose(10.seconds) + } @@ nonFlaky @@ TestAspect.diagnose(10.seconds), + test("interruption propagation") { + for { + interrupted <- Ref.make(false) + latch <- Promise.make[Nothing, Unit] + fib <- + ZStream(()) + .mapZIOParUnordered(1)(_ => (latch.succeed(()) *> ZIO.infinity).onInterrupt(interrupted.set(true))) + .runDrain + .fork + _ <- latch.await + _ <- fib.interrupt + result <- interrupted.get + } yield assert(result)(isTrue) + }, + test("interrupts pending tasks when one of the tasks fails U") { + for { + interrupted <- Ref.make(0) + latch1 <- Promise.make[Nothing, Unit] + latch2 <- Promise.make[Nothing, Unit] + result <- ZStream(1, 2, 3) + .mapZIOParUnordered(3) { + case 1 => (latch1.succeed(()) *> ZIO.never).onInterrupt(interrupted.update(_ + 1)) + case 2 => (latch2.succeed(()) *> ZIO.never).onInterrupt(interrupted.update(_ + 1)) + case 3 => latch1.await *> latch2.await *> ZIO.fail("Boom") + } + .runDrain + .exit + count <- interrupted.get + } yield assert(count)(equalTo(2)) && assert(result)(fails(equalTo("Boom"))) + } @@ nonFlaky, + test("awaits children fibers properly") { + assertZIO( + ZStream + .fromIterable((0 to 100)) + .interruptWhen(ZIO.never) + .mapZIOParUnordered(8)(_ => ZIO.succeed(1).repeatN(2000)) + .runDrain + .exit + .map(_.isInterrupted) + )(equalTo(false)) + }, + test("propagates error of original stream") { + for { + fiber <- (ZStream(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) ++ ZStream.fail(new Throwable("Boom"))) + .mapZIOParUnordered(2)(_ => ZIO.sleep(1.second)) + .runDrain + .fork + _ <- TestClock.adjust(5.seconds) + exit <- fiber.await + } yield assert(exit)(fails(hasMessage(equalTo("Boom")))) + } ), suite("mergeLeft/Right")( test("mergeLeft with HaltStrategy.Right terminates as soon as the right stream terminates") { diff --git a/streams/shared/src/main/scala/zio/stream/ZChannel.scala b/streams/shared/src/main/scala/zio/stream/ZChannel.scala index a86bca2b56c..7387fb1986f 100644 --- a/streams/shared/src/main/scala/zio/stream/ZChannel.scala +++ b/streams/shared/src/main/scala/zio/stream/ZChannel.scala @@ -4,6 +4,7 @@ import zio.{ZIO, _} import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.stream.internal.{AsyncInputConsumer, AsyncInputProducer, ChannelExecutor, SingleProducerAsyncInput} import ChannelExecutor.ChannelState +import zio.stream.ZChannel.QRes /** * A `ZChannel[Env, InErr, InElem, InDone, OutErr, OutElem, OutDone]` is a nexus @@ -630,57 +631,185 @@ sealed trait ZChannel[-Env, -InErr, -InElem, -InDone, +OutErr, +OutElem, +OutDon final def mapOutZIOPar[Env1 <: Env, OutErr1 >: OutErr, OutElem2](n: Int)( f: OutElem => ZIO[Env1, OutErr1, OutElem2] )(implicit trace: Trace): ZChannel[Env1, InErr, InElem, InDone, OutErr1, OutElem2, OutDone] = - ZChannel.unwrapScopedWith { scope => - for { - input <- SingleProducerAsyncInput.make[InErr, InElem, InDone] - queueReader = ZChannel.fromInput(input) - queue <- Queue.bounded[ZIO[Env1, OutErr1, Either[OutDone, OutElem2]]](n) - _ <- scope.addFinalizer(queue.shutdown) - errorSignal <- Promise.make[OutErr1, Nothing] - permits <- Semaphore.make(n.toLong) - pull <- (queueReader >>> self).toPullIn(scope) - _ <- pull - .foldCauseZIO( - cause => queue.offer(ZIO.refailCause(cause)), - { - case Left(outDone) => - permits.withPermits(n.toLong)(ZIO.unit).interruptible *> queue.offer(ZIO.succeed(Left(outDone))) - case Right(outElem) => - for { - p <- Promise.make[OutErr1, OutElem2] - latch <- Promise.make[Nothing, Unit] - _ <- queue.offer(p.await.map(Right(_))) - _ <- permits.withPermit { - latch.succeed(()) *> - ZIO.uninterruptibleMask { restore => - restore(errorSignal.await) raceFirstAwait restore(f(outElem)) - } - .tapErrorCause(errorSignal.failCause) - .intoPromise(p) - }.forkIn(scope) - _ <- latch.await - } yield () - } - ) - .forever - .interruptible - .forkIn(scope) - } yield { - lazy val consumer: ZChannel[Env1, Any, Any, Any, OutErr1, OutElem2, OutDone] = - ZChannel.unwrap[Env1, Any, Any, Any, OutErr1, OutElem2, OutDone] { - queue.take.flatten.foldCause( - ZChannel.refailCause, - { - case Left(outDone) => ZChannel.succeedNow(outDone) - case Right(outElem) => ZChannel.write(outElem) *> consumer - } + mapOutZIOPar[Env1, OutErr1, OutElem2](n, n)(f) + + final def mapOutZIOPar[Env1 <: Env, OutErr1 >: OutErr, OutElem2](n: Int, bufferSize: Int)( + f: OutElem => ZIO[Env1, OutErr1, OutElem2] + )(implicit trace: Trace): ZChannel[Env1, InErr, InElem, InDone, OutErr1, OutElem2, OutDone] = { + val z: ZIO[Any, Nothing, ZChannel[Env1, InErr, InElem, InDone, OutErr1, OutElem2, OutDone]] = for { + input <- SingleProducerAsyncInput.make[InErr, InElem, InDone] + queueReader = ZChannel.fromInput(input) + queue <- Queue.bounded[Fiber[Option[OutErr1], OutElem2]](bufferSize) + permits <- zio.Semaphore.make(n) + failureSignal <- Promise.make[Option[OutErr1], Nothing] + outDoneSignal <- Promise.make[Nothing, OutDone] + } yield { + def forkF(a: OutElem): ZIO[Env1, Nothing, Fiber.Runtime[Option[OutErr1], OutElem2]] = ZIO.uninterruptibleMask { + restore => + for { + localScope <- zio.Scope.make + _ <- restore(permits.withPermitScoped.provideEnvironment(ZEnvironment(localScope))) + fib <- restore { + f(a).catchAllCause { c => + failureSignal.failCause(c.map(Some(_))) *> ZIO.refailCause(c.map(Some(_))) + } + .raceWith[Env1, Option[OutErr1], Option[OutErr1], Nothing, OutElem2](failureSignal.await)( + { case (leftEx, rightFib) => + rightFib.interrupt *> leftEx + }, + { case (rightEx, leftFib) => + leftFib.interrupt *> rightEx + } + ) + } + .onExit(localScope.close(_)) + .fork + } yield fib + } + + lazy val enqueueCh: ZChannel[Env1, OutErr, OutElem, OutDone, Nothing, Nothing, Any] = ZChannel + .readWithCause( + in => ZChannel.fromZIO(forkF(in).flatMap(queue.offer(_))) *> enqueueCh, + err => + ZChannel.fromZIO( + failureSignal.failCause(err.map(Some(_))) *> queue.offer(Fiber.failCause(err.map(Some(_)))) + ), + done => ZChannel.fromZIO(outDoneSignal.succeed(done) *> queue.offer(Fiber.fail(None))) + ) + + val enqueuer: ZIO[Env1 with Scope, Nothing, Fiber.Runtime[Nothing, Any]] = queueReader + .pipeTo(self) + .pipeTo(enqueueCh) + .runScoped + .forkScoped + + lazy val readerCh: ZChannel[Any, Any, Any, Any, OutErr1, OutElem2, OutDone] = + ZChannel.unwrap { + val z0: URIO[Any, ZChannel[Any, Any, Any, Any, OutErr1, OutElem2, OutDone]] = queue.take + .flatMap(_.join) + .foldCause( + c => + Cause + .flipCauseOption(c) + .map(ZChannel.refailCause(_)) + .getOrElse(ZChannel.fromZIO(outDoneSignal.await)), + out2 => ZChannel.write(out2) *> readerCh ) + z0 + } + + val resCh: ZChannel[Env1, InErr, InElem, InDone, OutErr1, OutElem2, OutDone] = ZChannel + .scoped[Env1](enqueuer) + .concatMapWith { enqueueFib => + readerCh + }((_, o) => o, (o, _) => o) + .embedInput(input) + + resCh + } + + ZChannel.unwrap(z) + } + + final def mapOutZIOParUnordered[Env1 <: Env, OutErr1 >: OutErr, OutElem2](n: Int, bufferSize: Int = 16)( + f: OutElem => ZIO[Env1, OutErr1, OutElem2] + )(implicit trace: Trace): ZChannel[Env1, InErr, InElem, InDone, OutErr1, OutElem2, OutDone] = { + val z0: ZIO[Any, Nothing, ZChannel[Env1, InErr, InElem, InDone, OutErr1, OutElem2, OutDone]] = for { + input <- SingleProducerAsyncInput.make[InErr, InElem, InDone] + queueReader = ZChannel.fromInput(input) + q <- zio.Queue.bounded[Any](bufferSize) + permits <- zio.Semaphore.make(n) + } yield { + def enqueue(a: OutElem): ZIO[Env1, Nothing, Unit] = ZIO.uninterruptibleMask { restore => + for { + localScope <- zio.Scope.make + _ <- restore(permits.withPermitScoped.provideEnvironment(ZEnvironment(localScope))) + fib <- { + restore { + val z1 = f(a) + z1 + .foldCauseZIO( + c => q.offer(QRes.failCause(c)), + a2 => q.offer(a2) + ) + } + .onExit(localScope.close(_)) + .unit + .fork } + } yield () + } - consumer.embedInput(input) + val foreachCh: ZChannel[Env1, OutErr, OutElem, OutDone, OutErr1, Nothing, OutDone] = { + lazy val proc: ZChannel[Env1, OutErr, OutElem, OutDone, OutErr1, Nothing, OutDone] = + ZChannel.readWithCause( + in => { + ZChannel.fromZIO(enqueue(in)) *> proc + }, + ZChannel.refailCause(_), + ZChannel.succeedNow(_) + ) + + proc + } + + val enqueuer = queueReader + .pipeTo(self) + .pipeTo(foreachCh) + .runScoped + .foldCauseZIO( + //this message terminates processing so it's ok for it to race with in-flight computations + c => { + q.offer(QRes.failCause(c)) + }, + done => { + //make sure this is the last message in the queue + permits + .withPermits(n) { + q.offer(QRes(done)) + } + } + ) + .forkScoped + + val reader: ZChannel[Any, Any, Any, Any, OutErr1, OutElem2, OutDone] = { + lazy val reader0: ZChannel[Any, Any, Any, Any, OutErr1, OutElem2, OutDone] = ZChannel + .fromZIO(q.take) + .flatMap { + case QRes(v) => + v match { + case c: Cause[OutErr1] @unchecked => + ZChannel.refailCause(c) + case done: OutDone @unchecked => + ZChannel.succeedNow(done) + } + case a2: OutElem2 @unchecked => + ZChannel.write(a2) *> reader0 + } + + reader0 } + + val res0 = ZChannel + .scoped[Env1](enqueuer) + .concatMapWith { fib => + reader + }( + { case (_, done) => + done + }, + { case (done, _) => + done + } + ) + .embedInput(input) + + res0 } + ZChannel.unwrap(z0) + } + /** * Returns a new channel which creates a new channel for each emitted element * and merges some of them together. Different merge strategies control what @@ -2131,4 +2260,11 @@ object ZChannel { ): ZChannel[Env1, InErr, InElem, InDone, OutErr, OutElem, OutDone] = self.provideSomeEnvironment(_.updateAt(key)(f)) } + + private case class QRes[A](value: A) extends AnyVal + + private object QRes { + val unit: QRes[Unit] = QRes(()) + def failCause[E](c: Cause[E]): QRes[Cause[E]] = QRes(c) + } } diff --git a/streams/shared/src/main/scala/zio/stream/ZPipeline.scala b/streams/shared/src/main/scala/zio/stream/ZPipeline.scala index 028bf7783d7..d3c47509c94 100644 --- a/streams/shared/src/main/scala/zio/stream/ZPipeline.scala +++ b/streams/shared/src/main/scala/zio/stream/ZPipeline.scala @@ -496,6 +496,11 @@ final class ZPipeline[-Env, +Err, -In, +Out] private ( ): ZPipeline[Env2, Err2, In, Out2] = self >>> ZPipeline.mapZIOPar(n)(f) + def mapZIOPar[Env2 <: Env, Err2 >: Err, Out2](n: => Int, bufferSize: => Int)(f: Out => ZIO[Env2, Err2, Out2])(implicit + trace: Trace + ): ZPipeline[Env2, Err2, In, Out2] = + self >>> ZPipeline.mapZIOPar(n, bufferSize)(f) + /** * Maps over elements of the stream with the specified effectful function, * executing up to `n` invocations of `f` concurrently. The element order is @@ -506,6 +511,13 @@ final class ZPipeline[-Env, +Err, -In, +Out] private ( ): ZPipeline[Env2, Err2, In, Out2] = self >>> ZPipeline.mapZIOParUnordered(n)(f) + def mapZIOParUnordered[Env2 <: Env, Err2 >: Err, Out2](n: => Int, bufferSize: => Int)( + f: Out => ZIO[Env2, Err2, Out2] + )(implicit + trace: Trace + ): ZPipeline[Env2, Err2, In, Out2] = + self >>> ZPipeline.mapZIOParUnordered(n, bufferSize)(f) + /** * Transforms the errors emitted by this pipeline using `f`. */ @@ -1799,13 +1811,18 @@ object ZPipeline extends ZPipelinePlatformSpecificConstructors { def mapZIOPar[Env, Err, In, Out](n: => Int)(f: In => ZIO[Env, Err, Out])(implicit trace: Trace ): ZPipeline[Env, Err, In, Out] = - new ZPipeline( - ZChannel - .identity[Nothing, Chunk[In], Any] - .concatMap(ZChannel.writeChunk(_)) - .mapOutZIOPar(n)(f) - .mapOut(Chunk.single) - ) + ZPipeline.fromFunction { (strm: ZStream[Any, Nothing, In]) => + strm + .mapZIOPar(n)(f) + } + + def mapZIOPar[Env, Err, In, Out](n: => Int, bufferSize: => Int)(f: In => ZIO[Env, Err, Out])(implicit + trace: Trace + ): ZPipeline[Env, Err, In, Out] = + ZPipeline.fromFunction { (strm: ZStream[Any, Nothing, In]) => + strm + .mapZIOPar(n, bufferSize)(f) + } /** * Maps over elements of the stream with the specified effectful function, @@ -1815,12 +1832,16 @@ object ZPipeline extends ZPipelinePlatformSpecificConstructors { def mapZIOParUnordered[Env, Err, In, Out](n: => Int)(f: In => ZIO[Env, Err, Out])(implicit trace: Trace ): ZPipeline[Env, Err, In, Out] = - new ZPipeline( - ZChannel - .identity[Nothing, Chunk[In], Any] - .concatMap(ZChannel.writeChunk(_)) - .mergeMap(n, 16)(in => ZStream.fromZIO(f(in)).channel) - ) + ZPipeline.fromFunction { (strm: ZStream[Any, Nothing, In]) => + strm.mapZIOParUnordered(n)(f) + } + + def mapZIOParUnordered[Env, Err, In, Out](n: => Int, bufferSize: => Int)(f: In => ZIO[Env, Err, Out])(implicit + trace: Trace + ): ZPipeline[Env, Err, In, Out] = + ZPipeline.fromFunction { (strm: ZStream[Any, Nothing, In]) => + strm.mapZIOParUnordered(n, bufferSize)(f) + } /** * Emits the provided chunk before emitting any other value. diff --git a/streams/shared/src/main/scala/zio/stream/ZStream.scala b/streams/shared/src/main/scala/zio/stream/ZStream.scala index 706b78024e5..fa5dbe6feaa 100644 --- a/streams/shared/src/main/scala/zio/stream/ZStream.scala +++ b/streams/shared/src/main/scala/zio/stream/ZStream.scala @@ -17,11 +17,11 @@ package zio.stream import zio._ -import zio.internal.{SingleThreadedRingBuffer, UniqueKey} +import zio.internal.{PartitionedRingBuffer, SingleThreadedRingBuffer, UniqueKey} import zio.metrics.MetricLabel import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.stm._ -import zio.stream.ZStream.{DebounceState, HandoffSignal, zipChunks} +import zio.stream.ZStream.{DebounceState, HandoffSignal, failCause, zipChunks} import zio.stream.internal.{ZInputStream, ZReader} import java.io.{IOException, InputStream} @@ -1931,7 +1931,15 @@ final class ZStream[-R, +E, +A] private (val channel: ZChannel[R, Any, Any, Any, def mapZIOPar[R1 <: R, E1 >: E, A2](n: => Int)(f: A => ZIO[R1, E1, A2])(implicit trace: Trace ): ZStream[R1, E1, A2] = - self >>> ZPipeline.mapZIOPar(n)(f) + self.mapZIOPar[R1, E1, A2](n, n)(f) + + def mapZIOPar[R1 <: R, E1 >: E, A2](n: => Int, bufferSize: Int)(f: A => ZIO[R1, E1, A2])(implicit + trace: Trace + ): ZStream[R1, E1, A2] = + self.toChannel + .concatMap(ZChannel.writeChunk(_)) + .mapOutZIOPar[R1, E1, Chunk[A2]](n, bufferSize)(a => f(a).map(Chunk.single(_))) + .toStream /** * Maps over elements of the stream with the specified effectful function, @@ -1955,7 +1963,15 @@ final class ZStream[-R, +E, +A] private (val channel: ZChannel[R, Any, Any, Any, def mapZIOParUnordered[R1 <: R, E1 >: E, A2](n: => Int)(f: A => ZIO[R1, E1, A2])(implicit trace: Trace ): ZStream[R1, E1, A2] = - self >>> ZPipeline.mapZIOParUnordered(n)(f) + mapZIOParUnordered[R1, E1, A2](n, 16)(f) + + def mapZIOParUnordered[R1 <: R, E1 >: E, A2](n: => Int, bufferSize: => Int)(f: A => ZIO[R1, E1, A2])(implicit + trace: Trace + ): ZStream[R1, E1, A2] = + self.toChannel + .concatMap(ZChannel.writeChunk(_)) + .mapOutZIOParUnordered[R1, E1, Chunk[A2]](n, bufferSize)(a => f(a).map(Chunk.single(_))) + .toStream /** * Merges this stream and the specified stream together.