Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize ZStream's mapZIOPar and mapZIOParUnordered #8819

Merged
merged 36 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
411f287
strm_mapZioPar_opt: introduce mapZioPar and mapZioParUnordered 'direc…
eyalfa Apr 26, 2024
6efc2d7
strm_mapZioPar_opt: stream.mapZIOParUnordered
eyalfa Apr 26, 2024
d911350
strm_mapZioPar_opt: avoid mapZIOPar atm, add tests for mapZIOParUnord…
eyalfa Apr 28, 2024
59533ad
Merge branch 'series/2.x' into strm_mapZioPar_opt
eyalfa Apr 28, 2024
277bf35
strm_mapZioPar_opt: mapZIOPar2
eyalfa Apr 28, 2024
58cabb4
strm_mapZioPar_opt: use queue of Exit rather than a queue of Take
eyalfa Apr 30, 2024
0c31b9d
strm_mapZioPar_opt: strm.mapZIOParUnordered, run upstream in a scoped…
eyalfa May 1, 2024
b7cb92e
strm_mapZioPar_opt__fiberChildren: make fiber's children collection f…
eyalfa May 1, 2024
6596142
strm_mapZioPar_opt: mapZIOParUnordered2, record counting impl
eyalfa May 3, 2024
560ebcd
strm_mapZioPar_opt: strm.mapZIOParUnordered, ditch counting, rely on …
eyalfa May 3, 2024
b9eb3ac
strm_mapZioPar_opt: make sure strm.mapZIOParUnordered doesn't break w…
eyalfa May 3, 2024
42dfe12
Merge branch 'strm_mapZioPar_opt' into strm_mapZioPar_opt__fiberChildren
eyalfa May 3, 2024
6ad59d8
strm_mapZioPar_opt__fiberChildren: strm.mapZIOPar, use queue of fiber…
eyalfa May 5, 2024
7ad415b
strm_mapZioPar_opt__fiberChildren: fix issues in mapZIOPar
eyalfa May 5, 2024
fe66a4d
strm_mapZioPar_opt__chhannel: introduce a channel.mapZIOPar implement…
eyalfa May 5, 2024
1e2bfd8
strm_mapZioPar_opt__chhannel: duplicate the impls to the channel level
eyalfa May 6, 2024
8118984
strm_mapZioPar_opt__chhannel: add stream+pl methods for mapZIOPar and…
eyalfa May 6, 2024
077898f
Merge branch 'series/2.x' into strm_mapZioPar_opt
eyalfa May 6, 2024
5bbfe03
strm_mapZioPar_opt: fmt
eyalfa May 6, 2024
b228693
strm_mapZioPar_opt: remove ead code, add default prm
eyalfa May 6, 2024
e8cc44f
strm_mapZioPar_opt: fix CI issues
eyalfa May 6, 2024
d6c9793
strm_mapZioPar_opt: fix scala3 compilation issue
eyalfa May 6, 2024
6cf8265
strm_mapZioPar_opt: address reviw comments
eyalfa May 7, 2024
918aa6c
strm_mapZioPar_opt: address one more review comment
eyalfa May 7, 2024
11e09e1
Update core/shared/src/main/scala/zio/internal/FiberRuntime.scala
eyalfa May 7, 2024
08d6510
strm_mapZioPar_opt: slight compilation fix
eyalfa May 7, 2024
a7246ad
strm_mapZioPar_opt: make QRes an AnyVal
eyalfa May 8, 2024
552a6b4
Merge branch 'series/2.x' into strm_mapZioPar_opt
eyalfa May 8, 2024
6f81f82
strm_mapZioPar_opt__childrenSet: directly expose the children set to …
eyalfa May 10, 2024
9fd7ff3
strm_mapZioPar_opt__childrenSet: optimize transferChildren by batchin…
eyalfa May 10, 2024
3bdea62
strm_mapZioPar_opt__childrenSet: fmt
eyalfa May 10, 2024
d921095
Merge branch 'series/2.x' into strm_mapZioPar_opt
eyalfa May 10, 2024
eb1ef45
strm_mapZioPar_opt: FiberRuntime.children, extra safety measure to pr…
eyalfa May 10, 2024
ba2e2ef
Merge branch 'series/2.x' into strm_mapZioPar_opt
eyalfa May 12, 2024
b8d2dbd
strm_mapZioPar_opt: fiber runtime, back to synchronized child fibers set
eyalfa May 12, 2024
6edd498
Merge branch 'series/2.x' into strm_mapZioPar_opt
eyalfa May 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/shared/src/main/scala/zio/Fiber.scala
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ object Fiber extends FiberPlatformSpecific {
* '''NOTE''': This method must be invoked by the fiber itself.
*/
private[zio] def addChild(child: Fiber.Runtime[_, _]): Unit
private[zio] def addChildren(children: Iterable[Fiber.Runtime[_, _]]): Unit

/**
* Deletes the specified fiber ref.
Expand Down Expand Up @@ -619,6 +620,7 @@ object Fiber extends FiberPlatformSpecific {
* Adds a message to add a child to this fiber.
*/
private[zio] def tellAddChild(child: Fiber.Runtime[_, _]): Unit
private[zio] def tellAddChildren(children: Iterable[Fiber.Runtime[_, _]]): Unit

/**
* Adds a message to interrupt this fiber.
Expand Down
133 changes: 85 additions & 48 deletions core/shared/src/main/scala/zio/internal/FiberRuntime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,22 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs,
)
}

def children(implicit trace: Trace): UIO[Chunk[Fiber.Runtime[_, _]]] =
ZIO.succeed {
val childs = _children
if (childs == null) Chunk.empty
else
zio.internal.Sync(childs) {
Chunk.fromJavaIterable(childs)
}
private def childrenChunk = {
//may be executed by a foreign fiber (under Sync), hence we're risking a race over the _children variable being set back to null by a concurrent transferChildren call
val childs = _children
if (childs eq null) Chunk.empty
else {
val bldr = Chunk.newBuilder[Fiber.Runtime[_, _]]
childs.forEach { child =>
if ((child ne null) && child.isAlive())
bldr.addOne(child)
}
bldr.result()
}
}
eyalfa marked this conversation as resolved.
Show resolved Hide resolved

def children(implicit trace: Trace): UIO[Chunk[Fiber.Runtime[_, _]]] =
ZIO.succeed(self.childrenChunk)

def fiberRefs(implicit trace: Trace): UIO[FiberRefs] = ZIO.succeed(_fiberRefs)

Expand Down Expand Up @@ -151,14 +158,47 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs,
}

private[zio] def addChild(child: Fiber.Runtime[_, _]): Unit =
if (isAlive()) {
getChildren().add(child)
if (child.isAlive()) {
if (isAlive()) {
getChildren().add(child)

if (isInterrupted())
if (isInterrupted())
child.tellInterrupt(getInterruptedCause())
} else {
child.tellInterrupt(getInterruptedCause())
}
}

private[zio] def addChildren(children: Iterable[Fiber.Runtime[_, _]]): Unit = {
val iter = children.iterator
if (isAlive()) {
val childs = getChildren()

if (isInterrupted()) {
val cause = getInterruptedCause()
while (iter.hasNext) {
val child = iter.next()
if (child.isAlive()) {
childs.add(child)
child.tellInterrupt(cause)
}
}
} else {
while (iter.hasNext) {
val child = iter.next()
if (child.isAlive())
childs.add(child)
}
}
} else {
child.tellInterrupt(getInterruptedCause())
val cause = getInterruptedCause()
while (iter.hasNext) {
val child = iter.next()
if (child.isAlive())
child.tellInterrupt(cause)
}
}
}

/**
* Adds an interruptor to the set of interruptors that are interrupting this
Expand Down Expand Up @@ -483,7 +523,8 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs,
*
* '''NOTE''': This method must be invoked by the fiber itself.
*/
private[zio] def getChildren(): JavaSet[Fiber.Runtime[_, _]] = {
private def getChildren(): JavaSet[Fiber.Runtime[_, _]] = {
//executed by the fiber itself, no risk of racing with transferChildren
if (_children eq null) {
_children = Platform.newConcurrentWeakSet[Fiber.Runtime[_, _]]()(Unsafe.unsafe)
}
Expand Down Expand Up @@ -652,18 +693,29 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs,
private def interruptAllChildren(): UIO[Any] =
if (sendInterruptSignalToAllChildren(_children)) {
val iterator = _children.iterator()

_children = null

val body = () => {
val next = iterator.next()
var curr: Fiber.Runtime[_, _] = null

if (next != null) next.await(id.location) else Exit.unit
//this finds the next operable child fiber and stores it in the `curr` variable
def skip() = {
var next: Fiber.Runtime[_, _] = null
while (iterator.hasNext && (next eq null)) {
next = iterator.next()
if ((next ne null) && !next.isAlive())
next = null
}
curr = next
}

// Now await all children to finish:
ZIO
.whileLoop(iterator.hasNext)(body())(_ => ())(id.location)
//find the first operable child fiber
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

//if there isn't any we can simply return null and save ourselves an effect evaluation
skip()

if (null ne curr) {
ZIO
.whileLoop(null ne curr)(curr.await(id.location))(_ => skip())(id.location)
} else null
} else null

private[zio] def isAlive(): Boolean =
Expand Down Expand Up @@ -843,7 +895,7 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs,
*
* '''NOTE''': This method must be invoked by the fiber itself.
*/
private[zio] def removeChild(child: FiberRuntime[_, _]): Unit =
private def removeChild(child: FiberRuntime[_, _]): Unit =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks

if (_children ne null) {
_children.remove(child)
()
Expand Down Expand Up @@ -1170,16 +1222,6 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs,
done
}

private def sendInterruptSignalToAllChildrenConcurrently(): Boolean = {
val childFibers = _children

if (childFibers ne null) {
internal.Sync(childFibers) {
sendInterruptSignalToAllChildren(childFibers)
}
} else false
}

private def sendInterruptSignalToAllChildren(
children: JavaSet[Fiber.Runtime[_, _]]
): Boolean =
Expand Down Expand Up @@ -1375,6 +1417,9 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs,
private[zio] def tellAddChild(child: Fiber.Runtime[_, _]): Unit =
tell(FiberMessage.Stateful((parentFiber, _) => parentFiber.addChild(child)))

private[zio] def tellAddChildren(children: Iterable[Fiber.Runtime[_, _]]): Unit =
tell(FiberMessage.Stateful((parentFiber, _) => parentFiber.addChildren(children)))

private[zio] def tellInterrupt(cause: Cause[Nothing]): Unit =
tell(FiberMessage.InterruptSignal(cause))

Expand All @@ -1385,24 +1430,16 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs,
* '''NOTE''': This method must be invoked by the fiber itself after it has
* evaluated the effects but prior to exiting
*/
private[zio] def transferChildren(scope: FiberScope): Unit = {
val children = _children
if ((children ne null) && !children.isEmpty) {
val iterator = children.iterator()
val flags = _runtimeFlags

while (iterator.hasNext) {
val next = iterator.next()

// Only move alive children.
// Unless we forked fibers and didn't await them, we shouldn't have any alive children in the set.
if ((next ne null) && next.isAlive()) {
scope.add(self, flags, next)(location, Unsafe.unsafe)
iterator.remove()
}
}
private[zio] def transferChildren(scope: FiberScope): Unit =
if ((_children ne null) && !_children.isEmpty) {
val childs = childrenChunk
//we're effectively clearing this set, seems cheaper to 'drop' it and allocate a new one if we spawn more fibers
//a concurrent children call might get the stale set, but this method (and its primary usage for dumping fibers)
//is racy by definition
_children = null
val flags = _runtimeFlags
scope.addAll(self, flags, childs)(location, Unsafe.unsafe)
}
}

/**
* Updates a fiber ref belonging to this fiber by using the provided update
Expand Down
54 changes: 54 additions & 0 deletions core/shared/src/main/scala/zio/internal/FiberScope.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ private[zio] sealed trait FiberScope {
trace: Trace,
unsafe: Unsafe
): Unit

private[zio] def addAll(
currentFiber: Fiber.Runtime[_, _],
runtimeFlags: RuntimeFlags,
children: Iterable[Fiber.Runtime[_, _]]
)(implicit
trace: Trace,
unsafe: Unsafe
): Unit
}

private[zio] object FiberScope {
Expand All @@ -56,6 +65,20 @@ private[zio] object FiberScope {
if (RuntimeFlags.fiberRoots(runtimeFlags)) {
Fiber._roots.add(child)
}

private[zio] def addAll(
currentFiber: Fiber.Runtime[_, _],
runtimeFlags: RuntimeFlags,
children: Iterable[Fiber.Runtime[_, _]]
)(implicit
trace: Trace,
unsafe: Unsafe
): Unit =
if (RuntimeFlags.fiberRoots(runtimeFlags)) {
children.foreach {
Fiber._roots.add(_)
}
}
}

private final class Local(val fiberId: FiberId, parentRef: WeakReference[Fiber.Runtime[_, _]]) extends FiberScope {
Expand Down Expand Up @@ -85,6 +108,37 @@ private[zio] object FiberScope {
child.tellInterrupt(Cause.interrupt(currentFiber.id))
}
}

private[zio] def addAll(
currentFiber: Fiber.Runtime[_, _],
runtimeFlags: RuntimeFlags,
children: Iterable[Fiber.Runtime[_, _]]
)(implicit
trace: Trace,
unsafe: Unsafe
): Unit = if (children.nonEmpty) {
val parent = parentRef.get()

if (parent ne null) {
// Parent is not GC'd. Let's check to see if the parent is the current
// fiber:
if (currentFiber eq parent) {
// The parent is the current fiber so it is safe to directly add the
// child to the parent:
parent.addChildren(children)
} else {
// The parent is not the current fiber. So we need to send a message
// to the parent so it will add the child to itself:
parent.tellAddChildren(children)
}
} else {
// Parent was GC'd. We immediately interrupt the child fiber using the id
// of the current fiber (which is adding the child to the parent):
children.foreach(
_.tellInterrupt(Cause.interrupt(currentFiber.id))
)
}
}
}

private[zio] def make(fiber: FiberRuntime[_, _]): FiberScope =
Expand Down
53 changes: 52 additions & 1 deletion streams-tests/shared/src/test/scala/zio/stream/ZStreamSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2677,7 +2677,58 @@ object ZStreamSpec extends ZIOBaseSpec {
val stream =
ZStream.fromIterable(0 to 3).mapZIOParUnordered(10)(_ => ZIO.fail("fail"))
assertZIO(stream.runDrain.exit)(fails(equalTo("fail")))
} @@ nonFlaky @@ TestAspect.diagnose(10.seconds)
} @@ nonFlaky @@ TestAspect.diagnose(10.seconds),
test("interruption propagation") {
for {
interrupted <- Ref.make(false)
latch <- Promise.make[Nothing, Unit]
fib <-
ZStream(())
.mapZIOParUnordered(1)(_ => (latch.succeed(()) *> ZIO.infinity).onInterrupt(interrupted.set(true)))
.runDrain
.fork
_ <- latch.await
_ <- fib.interrupt
result <- interrupted.get
} yield assert(result)(isTrue)
},
test("interrupts pending tasks when one of the tasks fails U") {
for {
interrupted <- Ref.make(0)
latch1 <- Promise.make[Nothing, Unit]
latch2 <- Promise.make[Nothing, Unit]
result <- ZStream(1, 2, 3)
.mapZIOParUnordered(3) {
case 1 => (latch1.succeed(()) *> ZIO.never).onInterrupt(interrupted.update(_ + 1))
case 2 => (latch2.succeed(()) *> ZIO.never).onInterrupt(interrupted.update(_ + 1))
case 3 => latch1.await *> latch2.await *> ZIO.fail("Boom")
}
.runDrain
.exit
count <- interrupted.get
} yield assert(count)(equalTo(2)) && assert(result)(fails(equalTo("Boom")))
} @@ nonFlaky,
test("awaits children fibers properly") {
assertZIO(
ZStream
.fromIterable((0 to 100))
.interruptWhen(ZIO.never)
.mapZIOParUnordered(8)(_ => ZIO.succeed(1).repeatN(2000))
.runDrain
.exit
.map(_.isInterrupted)
)(equalTo(false))
},
test("propagates error of original stream") {
for {
fiber <- (ZStream(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) ++ ZStream.fail(new Throwable("Boom")))
.mapZIOParUnordered(2)(_ => ZIO.sleep(1.second))
.runDrain
.fork
_ <- TestClock.adjust(5.seconds)
exit <- fiber.await
} yield assert(exit)(fails(hasMessage(equalTo("Boom"))))
}
),
suite("mergeLeft/Right")(
test("mergeLeft with HaltStrategy.Right terminates as soon as the right stream terminates") {
Expand Down
Loading
Loading