Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel map2 optimization #3428

Merged
merged 12 commits into from
Apr 20, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

package cats.effect.kernel.instances

import cats.{~>, Align, Applicative, CommutativeApplicative, Eval, Functor, Monad, Parallel}
import cats.{~>, Align, Applicative, CommutativeApplicative, Functor, Monad, Parallel}
import cats.data.Ior
import cats.effect.kernel.{GenSpawn, Outcome, ParallelF}
import cats.effect.kernel.{GenSpawn, ParallelF}
import cats.implicits._

trait GenSpawnInstances {
Expand All @@ -41,7 +41,6 @@ trait GenSpawnInstances {
new (M ~> F) {
def apply[A](ma: M[A]): F[A] = ParallelF[M, A](ma)
}

}

implicit def commutativeApplicativeForParallelF[F[_], E](
Expand All @@ -51,134 +50,10 @@ trait GenSpawnInstances {
final override def pure[A](a: A): ParallelF[F, A] = ParallelF(F.pure(a))

final override def map2[A, B, Z](fa: ParallelF[F, A], fb: ParallelF[F, B])(
f: (A, B) => Z): ParallelF[F, Z] =
f: (A, B) => Z): ParallelF[F, Z] = {
ParallelF(
F.uncancelable { poll =>
for {
fiberA <- F.start(ParallelF.value(fa))
fiberB <- F.start(ParallelF.value(fb))

// start a pair of supervisors to ensure that the opposite is canceled on error
_ <- F start {
fiberB.join flatMap {
case Outcome.Succeeded(_) => F.unit
case _ => fiberA.cancel
}
}

_ <- F start {
fiberA.join flatMap {
case Outcome.Succeeded(_) => F.unit
case _ => fiberB.cancel
}
}

a <- F
.onCancel(poll(fiberA.join), bothUnit(fiberA.cancel, fiberB.cancel))
.flatMap[A] {
case Outcome.Succeeded(fa) =>
fa

case Outcome.Errored(e) =>
fiberB.cancel *> F.raiseError(e)

case Outcome.Canceled() =>
fiberB.cancel *> poll {
fiberB.join flatMap {
case Outcome.Succeeded(_) | Outcome.Canceled() =>
F.canceled *> F.never
case Outcome.Errored(e) =>
F.raiseError(e)
}
}
}

z <- F.onCancel(poll(fiberB.join), fiberB.cancel).flatMap[Z] {
case Outcome.Succeeded(fb) =>
fb.map(b => f(a, b))

case Outcome.Errored(e) =>
F.raiseError(e)

case Outcome.Canceled() =>
poll {
fiberA.join flatMap {
case Outcome.Succeeded(_) | Outcome.Canceled() =>
F.canceled *> F.never
case Outcome.Errored(e) =>
F.raiseError(e)
}
}
}
} yield z
}
)

final override def map2Eval[A, B, Z](fa: ParallelF[F, A], fb: Eval[ParallelF[F, B]])(
f: (A, B) => Z): Eval[ParallelF[F, Z]] =
Eval.now(
ParallelF(
F.uncancelable { poll =>
for {
fiberA <- F.start(ParallelF.value(fa))
fiberB <- F.start(ParallelF.value(fb.value))

// start a pair of supervisors to ensure that the opposite is canceled on error
_ <- F start {
fiberB.join flatMap {
case Outcome.Succeeded(_) => F.unit
case _ => fiberA.cancel
}
}

_ <- F start {
fiberA.join flatMap {
case Outcome.Succeeded(_) => F.unit
case _ => fiberB.cancel
}
}

a <- F
.onCancel(poll(fiberA.join), bothUnit(fiberA.cancel, fiberB.cancel))
.flatMap[A] {
case Outcome.Succeeded(fa) =>
fa

case Outcome.Errored(e) =>
fiberB.cancel *> F.raiseError(e)

case Outcome.Canceled() =>
fiberB.cancel *> poll {
fiberB.join flatMap {
case Outcome.Succeeded(_) | Outcome.Canceled() =>
F.canceled *> F.never
case Outcome.Errored(e) =>
F.raiseError(e)
}
}
}

z <- F.onCancel(poll(fiberB.join), fiberB.cancel).flatMap[Z] {
case Outcome.Succeeded(fb) =>
fb.map(b => f(a, b))

case Outcome.Errored(e) =>
F.raiseError(e)

case Outcome.Canceled() =>
poll {
fiberA.join flatMap {
case Outcome.Succeeded(_) | Outcome.Canceled() =>
F.canceled *> F.never
case Outcome.Errored(e) =>
F.raiseError(e)
}
}
}
} yield z
}
)
)
F.both(ParallelF.value(fa), ParallelF.value(fb)).map { case (a, b) => f(a, b) })
}

final override def ap[A, B](ff: ParallelF[F, A => B])(
fa: ParallelF[F, A]): ParallelF[F, B] =
Expand All @@ -194,10 +69,6 @@ trait GenSpawnInstances {

final override def unit: ParallelF[F, Unit] =
ParallelF(F.unit)

// assumed to be uncancelable
private[this] def bothUnit(a: F[Unit], b: F[Unit]): F[Unit] =
F.start(a).flatMap(f => b *> f.join.void)
}

implicit def alignForParallelF[F[_], E](implicit F: GenSpawn[F, E]): Align[ParallelF[F, *]] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ class PureConcSpec extends Specification with Discipline with BaseSpec {
pure.run((F.raiseError[Unit](42), F.never[Unit]).parTupled) mustEqual Outcome.Errored(42)
}

"short-circuit on canceled" in {
pure.run((F.never[Unit], F.canceled).parTupled.start.flatMap(_.join)) mustEqual Outcome
.Succeeded(Some(Outcome.canceled[F, Nothing, Unit]))
pure.run((F.canceled, F.never[Unit]).parTupled.start.flatMap(_.join)) mustEqual Outcome
.Succeeded(Some(Outcome.canceled[F, Nothing, Unit]))
}

"not run forever on chained product" in {
import cats.effect.kernel.Par.ParallelF

Expand Down
75 changes: 75 additions & 0 deletions tests/shared/src/test/scala/cats/effect/IOSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,81 @@ class IOSpec extends BaseSpec with Discipline with IOPlatformSpecification {
(IO.raiseError[Unit](TestException), IO.never[Unit]).parTupled.void must failAs(
TestException)
}

"short-circuit on canceled" in ticked { implicit ticker =>
(IO.never[Unit], IO.canceled)
.parTupled
.start
.flatMap(_.join.map(_.isCanceled)) must completeAs(true)
(IO.canceled, IO.never[Unit])
.parTupled
.start
.flatMap(_.join.map(_.isCanceled)) must completeAs(true)
}

"run finalizers when canceled" in ticked { implicit ticker =>
val tsk = IO.ref(0).flatMap { ref =>
val t = IO.never[Unit].onCancel(ref.update(_ + 1))
for {
fib <- (t, t).parTupled.start
_ <- IO { ticker.ctx.tickAll() }
_ <- fib.cancel
c <- ref.get
} yield c
}

tsk must completeAs(2)
}

"run right side finalizer when canceled (and left side already completed)" in ticked {
implicit ticker =>
val tsk = IO.ref(0).flatMap { ref =>
for {
fib <- (IO.unit, IO.never[Unit].onCancel(ref.update(_ + 1))).parTupled.start
_ <- IO { ticker.ctx.tickAll() }
_ <- fib.cancel
c <- ref.get
} yield c
}

tsk must completeAs(1)
}

"run left side finalizer when canceled (and right side already completed)" in ticked {
implicit ticker =>
val tsk = IO.ref(0).flatMap { ref =>
for {
fib <- (IO.never[Unit].onCancel(ref.update(_ + 1)), IO.unit).parTupled.start
_ <- IO { ticker.ctx.tickAll() }
_ <- fib.cancel
c <- ref.get
} yield c
}

tsk must completeAs(1)
}

"complete if both sides complete" in ticked { implicit ticker =>
val tsk = (
IO.sleep(2.seconds).as(20),
IO.sleep(3.seconds).as(22)
).parTupled.map { case (l, r) => l + r }

tsk must completeAs(42)
}

"not run forever on chained product" in ticked { implicit ticker =>
import cats.effect.kernel.Par.ParallelF

case object TestException extends RuntimeException

val fa: IO[String] = IO.pure("a")
val fb: IO[String] = IO.pure("b")
val fc: IO[Unit] = IO.raiseError[Unit](TestException)
val tsk =
ParallelF.value(ParallelF(fa).product(ParallelF(fb)).product(ParallelF(fc))).void
tsk must failAs(TestException)
}
}

"miscellaneous" should {
Expand Down