Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polymorphic async/await implementation #1924

Merged
merged 20 commits into from
May 31, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 14 additions & 9 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,12 @@ lazy val kernel = crossProject(JSPlatform, JVMPlatform)
libraryDependencies += "org.specs2" %%% "specs2-core" % Specs2Version % Test)
.settings(dottyLibrarySettings)
.settings(libraryDependencies += "org.typelevel" %%% "cats-core" % CatsVersion)
.jsSettings(
Compile / doc / sources := {
if (isDotty.value)
Seq()
else
(Compile / doc / sources).value
})
.jsSettings(Compile / doc / sources := {
if (isDotty.value)
Seq()
else
(Compile / doc / sources).value
})

/**
* Reference implementations (including a pure ConcurrentBracket), generic ScalaCheck
Expand Down Expand Up @@ -303,7 +302,8 @@ lazy val tests = crossProject(JSPlatform, JVMPlatform)
name := "cats-effect-tests",
libraryDependencies ++= Seq(
"org.typelevel" %%% "discipline-specs2" % DisciplineVersion % Test,
"org.typelevel" %%% "cats-kernel-laws" % CatsVersion % Test)
"org.typelevel" %%% "cats-kernel-laws" % CatsVersion % Test),
scalacOptions ++= List("-Xasync")
)
.jvmSettings(
Test / fork := true,
Expand All @@ -329,7 +329,12 @@ lazy val std = crossProject(JSPlatform, JVMPlatform)
else
"org.specs2" %%% "specs2-scalacheck" % Specs2Version % Test
},
libraryDependencies += "org.scalacheck" %%% "scalacheck" % ScalaCheckVersion % Test
libraryDependencies += "org.scalacheck" %%% "scalacheck" % ScalaCheckVersion % Test,
libraryDependencies ++= {
if (!isDotty.value)
Seq("org.scala-lang" % "scala-reflect" % scalaVersion.value % "provided")
else Seq()
}
)

/**
Expand Down
175 changes: 175 additions & 0 deletions std/jvm/src/main/scala-2/AsyncAwait.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Copyright 2020-2021 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect.std

import scala.annotation.compileTimeOnly
import scala.reflect.macros.whitebox

import cats.effect.std.Dispatcher
import cats.effect.kernel.Outcome
import cats.effect.kernel.Sync
import cats.effect.kernel.Async
import cats.effect.kernel.syntax.all._

class AsyncAwaitDsl[F[_]](implicit F: Async[F]) {

/**
* Type member used by the macro expansion to recover what `F` is without typetags
*/
type _AsyncContext[A] = F[A]

/**
* Value member used by the macro expansion to recover the Async instance associated to the block.
*/
implicit val _AsyncInstance: Async[F] = F

/**
* Non-blocking await the on result of `awaitable`. This may only be used directly within an enclosing `async` block.
*
* Internally, this will register the remainder of the code in enclosing `async` block as a callback
* in the `onComplete` handler of `awaitable`, and will *not* block a thread.
*/
@compileTimeOnly("[async] `await` must be enclosed in an `async` block")
def await[T](awaitable: F[T]): T =
??? // No implementation here, as calls to this are translated to `onComplete` by the macro.

/**
* Run the block of code `body` asynchronously. `body` may contain calls to `await` when the results of
* a `Future` are needed; this is translated into non-blocking code.
*/
def async[T](body: => T): F[T] = macro AsyncAwaitDsl.asyncImpl[F, T]

}

object AsyncAwaitDsl {

type Callback = Either[Throwable, AnyRef] => Unit

def asyncImpl[F[_], T](
c: whitebox.Context
djspiewak marked this conversation as resolved.
Show resolved Hide resolved
)(body: c.Tree): c.Tree = {
import c.universe._
if (!c.compilerSettings.contains("-Xasync")) {
c.abort(
c.macroApplication.pos,
"The async requires the compiler option -Xasync (supported only by Scala 2.12.12+ / 2.13.3+)"
)
} else
try {
val awaitSym = typeOf[AsyncAwaitDsl[Any]].decl(TermName("await"))
def mark(t: DefDef): Tree = {
import language.reflectiveCalls
c.internal
.asInstanceOf[{
def markForAsyncTransform(
owner: Symbol,
method: DefDef,
awaitSymbol: Symbol,
config: Map[String, AnyRef]
): DefDef
}]
.markForAsyncTransform(
c.internal.enclosingOwner,
t,
awaitSym,
Map.empty
)
}
val name = TypeName("stateMachine$async")
// format: off
q"""
final class $name(dispatcher: _root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext], callback: _root_.cats.effect.std.AsyncAwaitDsl.Callback) extends _root_.cats.effect.std.AsyncAwaitStateMachine(dispatcher, callback) {
${mark(q"""override def apply(tr$$async: _root_.cats.effect.kernel.Outcome[${c.prefix}._AsyncContext, _root_.scala.Throwable, _root_.scala.AnyRef]): _root_.scala.Unit = ${body}""")}
}
${c.prefix}._AsyncInstance.recoverWith {
_root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext].use { dispatcher =>
${c.prefix}._AsyncInstance.async_[_root_.scala.AnyRef](cb => new $name(dispatcher, cb).start())
}
}{
case _root_.cats.effect.std.AsyncAwaitDsl.CancelBridge =>
${c.prefix}._AsyncInstance.map(${c.prefix}._AsyncInstance.canceled)(_ => null.asInstanceOf[AnyRef])
}.asInstanceOf[${c.macroApplication.tpe}]
"""
} catch {
case e: ReflectiveOperationException =>
c.abort(
c.macroApplication.pos,
"-Xasync is provided as a Scala compiler option, but the async macro is unable to call c.internal.markForAsyncTransform. " + e.getClass.getName + " " + e.getMessage
)
}
}

// A marker exception to communicate cancellation through the async runtime.
object CancelBridge extends Throwable with scala.util.control.NoStackTrace
}

abstract class AsyncAwaitStateMachine[F[_]](
dispatcher: Dispatcher[F],
callback: AsyncAwaitDsl.Callback
)(implicit F: Sync[F]) extends Function1[Outcome[F, Throwable, AnyRef], Unit] {

// FSM translated method
//def apply(v1: Outcome[IO, Throwable, AnyRef]): Unit = ???

private[this] var state$async: Int = 0

/** Retrieve the current value of the state variable */
protected def state: Int = state$async

/** Assign `i` to the state variable */
protected def state_=(s: Int): Unit = state$async = s

protected def completeFailure(t: Throwable): Unit =
callback(Left(t))

protected def completeSuccess(value: AnyRef): Unit = {
callback(Right(value))
}

protected def onComplete(f: F[AnyRef]): Unit = {
dispatcher.unsafeRunAndForget(f.guaranteeCase(outcome => F.delay(this(outcome))))
}

protected def getCompleted(f: F[AnyRef]): Outcome[F, Throwable, AnyRef] = {
val _ = f
null
}

protected def tryGet(tr: Outcome[F, Throwable, AnyRef]): AnyRef =
tr match {
case Outcome.Succeeded(value) =>
// TODO discuss how to propagate "errors"" from other
// error channels than the Async's, such as None
// in OptionT. Maybe some ad-hoc polymorphic construct
// with a custom path-dependent "bridge" exception type...
// ... or something
dispatcher.unsafeRunSync(value)
case Outcome.Errored(e) =>
callback(Left(e))
this // sentinel value to indicate the dispatch loop should exit.
case Outcome.Canceled() =>
callback(Left(AsyncAwaitDsl.CancelBridge))
this
}

def start(): Unit = {
// Required to kickstart the async state machine.
// `def apply` does not consult its argument when `state == 0`.
apply(null)
}

}
121 changes: 121 additions & 0 deletions tests/jvm/src/test/scala-2/cats/effect/std/AsyncAwaitSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright 2020-2021 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect
package std

import scala.concurrent.duration._
import cats.syntax.all._
import cats.data.Kleisli

class AsyncAwaitSpec extends BaseSpec {

"IOAsyncAwait" should {
object IOAsyncAwait extends cats.effect.std.AsyncAwaitDsl[IO]
import IOAsyncAwait.{await => ioAwait, _}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

await collides with something in FutureMatchers.


"work on success" in real {

val io = IO.sleep(100.millis) >> IO.pure(1)

val program = async(ioAwait(io) + ioAwait(io))

program.flatMap { res =>
IO {
res must beEqualTo(2)
}
}
}

"propagate errors outward" in real {

case object Boom extends Throwable
val io = IO.raiseError[Int](Boom)

val program = async(ioAwait(io))

program.attempt.flatMap { res =>
IO {
res must beEqualTo(Left(Boom))
}
}
}

"propagate canceled outcomes outward" in real {

val io = IO.canceled

val program = async(ioAwait(io))

program.start.flatMap(_.join).flatMap { res =>
IO {
res must beEqualTo(Outcome.canceled[IO, Throwable, Unit])
}
}
}

"be cancellable" in real {

val program = for {
ref <- Ref[IO].of(0)
_ <- async { ioAwait(IO.sleep(100.millis) *> ref.update(_ + 1)) }
.start
.flatMap(_.cancel)
_ <- IO.sleep(200.millis)
result <- ref.get
} yield {
result
}

program.flatMap { res =>
IO {
res must beEqualTo(0)
}
}

}

"suspend side effects" in real {
var x = 0
val program = async(x += 1)

for {
before <- IO(x must beEqualTo(0))
_ <- program
after <- IO(x must beEqualTo(1))
} yield before && after
}
}

"KleisliAsyncAwait" should {
type F[A] = Kleisli[IO, Int, A]
object KleisliAsyncAwait extends cats.effect.std.AsyncAwaitDsl[F]
import KleisliAsyncAwait.{await => kAwait, _}

"work on successes" in real {
val io = Temporal[F].sleep(100.millis) >> Kleisli(x => IO.pure(x + 1))

val program = async(kAwait(io) + kAwait(io))

program.run(0).flatMap { res =>
IO {
res must beEqualTo(2)
}
}
}
}

}