Skip to content

Commit

Permalink
feat: rewrite subscription and demandUnfoldSink to not require unsafe…
Browse files Browse the repository at this point in the history
…Run (#287)

* feat: no unsafeRun in demand tracking

* feat: track demand only in subscription

* refactorings

* use AtomicReference to fix race conditions

* fix: hide internals of DemandTrackingSubscription

* refactoring

* fix StackOverfow on 2.11

* fix: defer side effect until after AtomicRef.getAndUpdate

* revert nonFlaky annotations
  • Loading branch information
runtologist committed Jan 4, 2022
1 parent ecf6305 commit a4359e9
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 46 deletions.
125 changes: 82 additions & 43 deletions src/main/scala/zio/interop/reactivestreams/Adapters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import zio.stream._
import zio.stream.ZStream.Pull

import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicReference

object Adapters {

Expand All @@ -19,12 +20,12 @@ object Adapters {
if (subscriber == null) {
throw new NullPointerException("Subscriber must not be null.")
} else {
val subscription = new DemandTrackingSubscription(subscriber)
runtime.unsafeRunAsync(
for {
demand <- Queue.unbounded[Long]
_ <- UIO(subscriber.onSubscribe(createSubscription(subscriber, demand, runtime)))
_ <- UIO(subscriber.onSubscribe(subscription))
_ <- stream
.run(demandUnfoldSink(subscriber, demand))
.run(demandUnfoldSink(subscriber, subscription))
.catchAll(e => UIO(subscriber.onError(e)))
.forkDaemon
} yield ()
Expand All @@ -37,17 +38,11 @@ object Adapters {
)(implicit trace: ZTraceElement): ZManaged[Any, Nothing, (E => UIO[Unit], ZSink[Any, Nothing, I, I, Unit])] = {
val sub = subscriber
for {
runtime <- ZIO.runtime[Any].toManaged
demand <- Queue.unbounded[Long].toManaged
subscription = createSubscription(sub, demand, runtime)
_ <- ZManaged.succeed(sub.onSubscribe(subscription))
errorSignaled <- Promise.makeManaged[Nothing, Boolean]
} yield {
val signalError =
(e: E) => ZIO.whenZIO(errorSignaled.complete(UIO.succeedNow(true)))(UIO(sub.onError(e)) *> demand.shutdown).unit

(signalError, demandUnfoldSink(sub, demand))
}
error <- Promise.makeManaged[E, Nothing]
subscription = new DemandTrackingSubscription(sub)
_ <- ZManaged.succeed(sub.onSubscribe(subscription))
_ <- error.await.catchAll(t => UIO(sub.onError(t))).toManaged.fork
} yield (error.fail(_).unit, demandUnfoldSink(sub, subscription))
}

def publisherToStream[O](
Expand Down Expand Up @@ -215,41 +210,85 @@ object Adapters {
}

private def demandUnfoldSink[I](
subscriber: Subscriber[_ >: I],
demand: Queue[Long]
subscriber: Subscriber[I],
subscription: DemandTrackingSubscription
): ZSink[Any, Nothing, I, I, Unit] =
ZSink
.foldChunksZIO[Any, Nothing, I, Long](0L)(_ >= 0L) { (bufferedDemand, chunk) =>
UIO
.iterate((chunk, bufferedDemand))(!_._1.isEmpty) { case (chunk, bufferedDemand) =>
demand.isShutdown.flatMap {
case true => UIO((Chunk.empty, -1))
case false =>
if (chunk.size.toLong <= bufferedDemand)
UIO
.foreachDiscard(chunk)(a => UIO(subscriber.onNext(a)))
.as((Chunk.empty, bufferedDemand - chunk.size.toLong))
else
UIO.foreachDiscard(chunk.take(bufferedDemand.toInt))(a => UIO(subscriber.onNext(a))) *>
demand.take.map((chunk.drop(bufferedDemand.toInt), _))
}
.foldChunksZIO[Any, Nothing, I, Boolean](true)(identity) { (_, chunk) =>
IO
.iterate(chunk)(!_.isEmpty) { chunk =>
subscription
.offer(chunk.size)
.flatMap { acceptedCount =>
UIO.foreach(chunk.take(acceptedCount))(a => UIO(subscriber.onNext(a))).as(chunk.drop(acceptedCount))
}
}
.map(_._2)
.fold(
_ => false, // canceled
_ => true
)
}
.map(_ => if (!subscription.isCanceled) subscriber.onComplete())

private class DemandTrackingSubscription(subscriber: Subscriber[_]) extends Subscription {

private case class State(
requestedCount: Long, // -1 when cancelled
toNotify: Option[(Int, Promise[Unit, Int])]
)

private val initial = State(0L, None)
private val canceled = State(-1, None)
private def requested(n: Long) = State(n, None)
private def awaiting(n: Int, p: Promise[Unit, Int]) = State(0L, Some((n, p)))

private val state = new AtomicReference(initial)

def offer(n: Int): IO[Unit, Int] = {
var result: IO[Unit, Int] = null
state.updateAndGet {
case `canceled` =>
result = IO.fail(())
canceled
case State(0L, _) =>
val p = Promise.unsafeMake[Unit, Int](FiberId.None)
result = p.await
awaiting(n, p)
case State(requestedCount, _) =>
val newRequestedCount = Math.max(requestedCount - n, 0L)
val accepted = Math.min(requestedCount, n.toLong).toInt
result = IO.succeedNow(accepted)
requested(newRequestedCount)
}
.mapZIO(_ => demand.isShutdown.flatMap(is => UIO(subscriber.onComplete()).when(!is).unit))

private def createSubscription[A](
subscriber: Subscriber[_ >: A],
demand: Queue[Long],
runtime: Runtime[_]
): Subscription =
new Subscription {
override def request(n: Long): Unit =
if (n <= 0) subscriber.onError(new IllegalArgumentException("non-positive subscription request"))
else runtime.unsafeRunAsync(demand.offer(n))
override def cancel(): Unit = runtime.unsafeRun(demand.shutdown)
result
}

def isCanceled: Boolean = state.get().requestedCount < 0

override def request(n: Long): Unit = {
if (n <= 0) subscriber.onError(new IllegalArgumentException("non-positive subscription request"))
var notification: () => Unit = () => ()
state.getAndUpdate {
case `canceled` =>
canceled
case State(requestedCount, Some((offered, toNotify))) =>
val newRequestedCount = requestedCount + n
val accepted = Math.min(offered.toLong, newRequestedCount)
val remaining = newRequestedCount - accepted
notification = () => toNotify.unsafeDone(IO.succeedNow(accepted.toInt))
requested(remaining)
case State(requestedCount, _) if ((Long.MaxValue - n) > requestedCount) =>
requested(requestedCount + n)
case _ =>
requested(Long.MaxValue)
}
notification()
}

override def cancel(): Unit =
state.getAndSet(canceled).toNotify.foreach { case (_, p) => p.unsafeDone(IO.fail(())) }
}

private def fromPull[R, E, A](zio: ZManaged[R, Nothing, ZIO[R, Option[E], Chunk[A]]])(implicit
trace: ZTraceElement
): ZStream[R, E, A] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ object StreamToPublisherSpec extends DefaultRunnableSpec {
r <- Task
.attemptBlockingInterrupt(method.invoke(pv))
.unit
.refineOrDie { case e: InvocationTargetException =>
e.getTargetException()
}
.refineOrDie { case e: InvocationTargetException => e.getTargetException() }
.exit
} yield assert(r)(succeeds(isUnit))
)
Expand Down

0 comments on commit a4359e9

Please sign in to comment.