Skip to content

Commit

Permalink
Add ZSink#zipPar (#1344)
Browse files Browse the repository at this point in the history
* add Sink.zipPar

* ZSink.zipPar scaladoc

* fixes after code review
  • Loading branch information
simpadjo authored and iravid committed Aug 7, 2019
1 parent 3f9d372 commit 5e8c9ae
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 0 deletions.
148 changes: 148 additions & 0 deletions streams-tests/jvm/src/test/scala/zio/stream/SinkSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import zio.clock.Clock
import zio.duration._
import zio.test.mock.MockClock
import java.util.concurrent.TimeUnit
import org.specs2.matcher.MatchResult
import org.specs2.matcher.describe.Diffable

class SinkSpec(implicit ee: org.specs2.concurrent.ExecutionEnv)
extends TestRuntime
Expand Down Expand Up @@ -172,6 +174,17 @@ class SinkSpec(implicit ee: org.specs2.concurrent.ExecutionEnv)
zipLeft (<*)
happy path $zipLeftHappyPath

zipPar
happy path 1 $zipParHappyPathBothDone
happy path 2 $zipParHappyPathOneNonterm
happy path 3 $zipParHappyPathBothNonterm
extract error $zipParErrorExtract
step error $zipParErrorStep
init error $zipParErrorInit
both error $zipParErrorBoth
remainder corner case 1 $zipParRemainderWhenCompleteSeparately
remainder corner case 2 $zipParRemainderWhenCompleteTogether

zipRight (*>)
happy path $zipRightHappyPath

Expand Down Expand Up @@ -243,6 +256,31 @@ class SinkSpec(implicit ee: org.specs2.concurrent.ExecutionEnv)
def extract(state: State) = IO.fail("Ouch")
}

/** Searches for the `target` element in the stream.
* When met - accumulates next `accumulateAfterMet` elements and returns as `leftover`
* If `target` is not met - returns `default` with empty `leftover`
*/
private def sinkWithLeftover[A](target: A, accumulateAfterMet: Int, default: A) = new ZSink[Any, String, A, A, A] {
override type State = Option[List[A]]

override def extract(state: Option[List[A]]): ZIO[Any, String, A] =
UIO.succeed(if (state.isEmpty) default else target)

override def initial: ZIO[Any, String, Step[Option[List[A]], Nothing]] = UIO.succeed(Step.more(None))

override def step(state: Option[List[A]], a: A): ZIO[Any, String, Step[Option[List[A]], A]] =
state match {
case None =>
val st = if (a == target) Some(Nil) else None
UIO.succeed(Step.more(st))
case Some(acc) =>
if (acc.length >= accumulateAfterMet)
UIO.succeed(Step.done(state, Chunk.fromIterable(acc)))
else
UIO.succeed(Step.more(Some(acc :+ a)))
}
}

private def sinkIteration[R, E, A0, A, B](sink: ZSink[R, E, A0, A, B], a: A) =
for {
init <- sink.initial
Expand Down Expand Up @@ -800,6 +838,116 @@ class SinkSpec(implicit ee: org.specs2.concurrent.ExecutionEnv)
unsafeRun(sinkIteration(sink, 1).map(_ must_=== "1Hello"))
}

private object ZipParLaws {
def coherence[A, B: Diffable, C: Diffable](
s: Stream[String, A],
sink1: ZSink[Any, String, A, A, B],
sink2: ZSink[Any, String, A, A, C]
): MatchResult[Either[String, Any]] =
unsafeRun {
for {
zb <- s.run(sink1).either
zc <- s.run(sink2).either
zbc <- s.run(sink1.zipPar(sink2)).either
} yield {
zbc match {
case Left(e) => (zb must beLeft(e)) or (zc must beLeft(e))
case Right((b, c)) => (zb must beRight(b)) and (zc must beRight(c))
}
}
}

def swap[A, B: Diffable, C: Diffable](
s: Stream[String, A],
sink1: ZSink[Any, String, A, A, B],
sink2: ZSink[Any, String, A, A, C]
) =
unsafeRun {
for {
res <- s.run(sink1.zipPar(sink2).zip(ZSink.collectAll[A])).either
swapped <- s.run(sink2.zipPar(sink1).zip(ZSink.collectAll[A])).either
} yield {
swapped must_=== res.map {
case ((b, c), rem) => ((c, b), rem)
}
}
}

def remainders[A, B: Diffable, C: Diffable](
s: Stream[String, A],
sink1: ZSink[Any, String, A, A, B],
sink2: ZSink[Any, String, A, A, C]
): MatchResult[AnyVal] =
unsafeRun {
val maybeProp = for {
rem1 <- s.run(sink1.zipRight(ZSink.collectAll[A]))
rem2 <- s.run(sink2.zipRight(ZSink.collectAll[A]))
rem <- s.run(sink1.zipPar(sink2).zipRight(ZSink.collectAll[A]))
} yield {
val (longer, shorter) = if (rem1.length <= rem2.length) (rem2, rem1) else (rem1, rem2)
longer must_=== rem
rem.endsWith(shorter) must_=== true
}
//irrelevant if an error occurred
maybeProp.catchAll(_ => UIO.succeed(1 must_=== 1))
}

def laws[A, B: Diffable, C: Diffable](
s: Stream[String, A],
sink1: ZSink[Any, String, A, A, B],
sink2: ZSink[Any, String, A, A, C]
): MatchResult[Any] =
coherence(s, sink1, sink2) and remainders(s, sink1, sink2) and swap(s, sink1, sink2)
}

private def zipParHappyPathBothDone = {
val sink1 = ZSink.collectAllWhile[Int](_ < 5)
val sink2 = ZSink.collectAllWhile[Int](_ < 3)
ZipParLaws.laws(Stream(1, 2, 3, 4, 5, 6), sink1, sink2)
}

private def zipParHappyPathOneNonterm = {
val sink1 = ZSink.collectAllWhile[Int](_ < 5)
val sink2 = ZSink.collectAllWhile[Int](_ < 30)
ZipParLaws.laws(Stream(1, 2, 3, 4, 5, 6), sink1, sink2)
}

private def zipParHappyPathBothNonterm = {
val sink1 = ZSink.collectAllWhile[Int](_ < 50)
val sink2 = ZSink.collectAllWhile[Int](_ < 30)
ZipParLaws.laws(Stream(1, 2, 3, 4, 5, 6), sink1, sink2)
}

private def zipParErrorExtract = {
val sink1 = ZSink.collectAllWhile[Int](_ < 5)
ZipParLaws.laws(Stream(1, 2, 3, 4, 5, 6), sink1, extractErrorSink)
}

private def zipParErrorStep = {
val sink1 = ZSink.collectAllWhile[Int](_ < 5)
ZipParLaws.laws(Stream(1, 2, 3, 4, 5, 6), sink1, stepErrorSink)
}

private def zipParErrorInit = {
val sink1 = ZSink.collectAllWhile[Int](_ < 5)
ZipParLaws.laws(Stream(1, 2, 3, 4, 5, 6), sink1, initErrorSink)
}

private def zipParErrorBoth =
ZipParLaws.laws(Stream(1, 2, 3, 4, 5, 6), stepErrorSink, initErrorSink)

private def zipParRemainderWhenCompleteTogether = {
val sink1 = sinkWithLeftover(2, 3, -42)
val sink2 = sinkWithLeftover(2, 4, -42)
ZipParLaws.laws(Stream(1, 2, 3, 4, 5, 6), sink1, sink2)
}

private def zipParRemainderWhenCompleteSeparately = {
val sink1 = sinkWithLeftover(3, 1, -42)
val sink2 = sinkWithLeftover(2, 4, -42)
ZipParLaws.laws(Stream(1, 2, 3, 4, 5, 6), sink1, sink2)
}

private def foldLeft =
prop { (s: Stream[String, Int], f: (String, Int) => String, z: String) =>
unsafeRunSync(s.run(ZSink.foldLeft(z)(f))) must_=== slurp(s).map(_.foldLeft(z)(f))
Expand Down
87 changes: 87 additions & 0 deletions streams/shared/src/main/scala/zio/stream/ZSink.scala
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,93 @@ trait ZSink[-R, +E, +A0, -A, +B] { self =>
}
}

/**
* Runs both sinks in parallel on the input and combines the results into a Tuple.
*/
final def zipPar[R1 <: R, E1 >: E, A2 >: A0, A1 <: A, C](
that: ZSink[R1, E1, A2, A1, C]
): ZSink[R1, E1, A2, A1, (B, C)] =
new ZSink[R1, E1, A2, A1, (B, C)] {
type State = (Either[B, self.State], Either[C, that.State])

override def extract(state: State): ZIO[R1, E1, (B, C)] = {
val b: ZIO[R, E, B] = state._1.fold(ZIO.succeed, self.extract)
val c: ZIO[R1, E1, C] = state._2.fold(ZIO.succeed, that.extract)
b.zipPar(c)
}

override def initial: ZIO[R1, E1, Step[State, Nothing]] =
self.initial.flatMap { s1 =>
that.initial.flatMap { s2 =>
(Step.cont(s1), Step.cont(s2)) match {
case (false, false) =>
val zb = self.extract(Step.state(s1))
val zc = that.extract(Step.state(s2))
zb.zipWithPar(zc)((b, c) => Step.done((Left(b), Left(c)), Chunk.empty))

case (false, true) =>
val zb = self.extract(Step.state(s1))
zb.map(b => Step.more((Left(b), Right(Step.state(s2)))))

case (true, false) =>
val zc = that.extract(Step.state(s2))
zc.map(c => Step.more((Right(Step.state(s1)), Left(c))))

case (true, true) =>
ZIO.succeed(Step.more((Right(Step.state(s1)), Right(Step.state(s2)))))
}
}
}

override def step(state: State, a: A1): ZIO[R1, E1, Step[State, A2]] = {
val firstResult: ZIO[R, E, Either[(B, Option[Chunk[A2]]), self.State]] = state._1.fold(
b => ZIO.succeed(Left((b, None))),
s =>
self
.step(s, a)
.flatMap { (st: Step[ZSink.this.State, A0]) =>
if (Step.cont(st))
ZIO.succeed(Right(Step.state(st)))
else
self
.extract(Step.state(st))
.map(b => Left((b, Some(Step.leftover(st)))))
}
)

val secondResult: ZIO[R1, E1, Either[(C, Option[Chunk[A2]]), that.State]] = state._2.fold(
c => ZIO.succeed(Left((c, None))),
s =>
that
.step(s, a)
.flatMap { st =>
if (Step.cont(st))
ZIO.succeed(Right(Step.state(st)))
else
that
.extract(Step.state(st))
.map(c => {
val leftover: Chunk[A2] = (Step.leftover(st))
Left((c, Some(leftover)))
})
}
)

firstResult.zipPar(secondResult).map {
case (Left((b, rem1)), Left((c, rem2))) =>
val minLeftover =
if (rem1.isEmpty && rem2.isEmpty) Chunk.empty else (rem1.toList ++ rem2.toList).minBy(_.length)
Step.done((Left(b), Left(c)), minLeftover)

case (Left((b, _)), Right(s2)) =>
Step.more((Left(b), Right(s2)))

case (r: Right[_, _], Left((c, _))) => Step.more((r.asInstanceOf[Either[B, self.State]], Left(c)))
case rights @ (Right(_), Right(_)) => Step.more(rights.asInstanceOf[State])
}
}
}

/**
* Produces a sink consuming all the elements of type `A` as long as
* they verify the predicate `pred`.
Expand Down

0 comments on commit 5e8c9ae

Please sign in to comment.