Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polymorphic async/await implementation #1924

Merged
merged 20 commits into from
May 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,16 @@ lazy val tests = crossProject(JSPlatform, JVMPlatform)
name := "cats-effect-tests",
libraryDependencies ++= Seq(
"org.typelevel" %%% "discipline-specs2" % DisciplineVersion % Test,
"org.typelevel" %%% "cats-kernel-laws" % CatsVersion % Test)
"org.typelevel" %%% "cats-kernel-laws" % CatsVersion % Test),
scalacOptions ++= List("-Xasync"),
Test / unmanagedSourceDirectories ++= {
if (!isDotty.value)
Seq(
(Compile / baseDirectory)
.value
.getParentFile() / "shared" / "src" / "test" / "scala-2")
else Seq()
}
)
.jvmSettings(
Test / fork := true,
Expand All @@ -337,7 +346,20 @@ lazy val std = crossProject(JSPlatform, JVMPlatform)
else
"org.specs2" %%% "specs2-scalacheck" % Specs2Version % Test
},
libraryDependencies += "org.scalacheck" %%% "scalacheck" % ScalaCheckVersion % Test
libraryDependencies += "org.scalacheck" %%% "scalacheck" % ScalaCheckVersion % Test,
libraryDependencies ++= {
if (!isDotty.value)
Seq("org.scala-lang" % "scala-reflect" % scalaVersion.value % "provided")
else Seq()
},
Compile / unmanagedSourceDirectories ++= {
if (!isDotty.value)
Seq(
(Compile / baseDirectory)
.value
.getParentFile() / "shared" / "src" / "main" / "scala-2")
else Seq()
}
)

/**
Expand Down
223 changes: 223 additions & 0 deletions std/shared/src/main/scala-2/AsyncAwait.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
/*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor thing, but can we rename this file in lower-case to reflect the fact that it contains multiple top-level members? Like asyncAwait.scala.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

* Copyright 2020-2021 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect.std

import scala.annotation.compileTimeOnly
import scala.reflect.macros.blackbox

import cats.effect.kernel.Async
import cats.effect.kernel.syntax.all._
import cats.syntax.all._
import cats.effect.kernel.Outcome.Canceled
import cats.effect.kernel.Outcome.Errored
import cats.effect.kernel.Outcome.Succeeded

/**
* WARNING: This construct currently only works on scala 2 (2.12.12+ / 2.13.3+),
* relies on an experimental compiler feature enabled by the -Xasync
* scalac option, and should absolutely be considered unstable with
* regards to backward compatibility guarantees (be that source or binary).
*
* Partially applied construct allowing for async/await semantics,
* popularised in other programming languages.
*
* {{{
* object dsl extends AsyncAwaitDsl[IO]
* import dsl._
*
* val io: IO[Int] = ???
* async { await(io) + await(io) }
* }}}
*
* The code is transformed at compile time into a state machine
* that sequentially calls upon a [[Dispatcher]] every time it reaches
* an "await" block.
*/
class AsyncAwaitDsl[F[_]](implicit F: Async[F]) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add some scaladoc being like "uh, this is really really unstable, only works on Scala 2.x, and will likely change and/or be deprecated in the future"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love it.


/**
* Type member used by the macro expansion to recover what `F` is without typetags
*/
type _AsyncContext[A] = F[A]

/**
* Value member used by the macro expansion to recover the Async instance associated to the block.
*/
implicit val _AsyncInstance: Async[F] = F

/**
* Non-blocking await the on result of `awaitable`. This may only be used directly within an enclosing `async` block.
*
* Internally, this transforms the remainder of the code in enclosing `async` block into a callback
* that triggers upon a successful computation outcome. It does *not* block a thread.
*/
@compileTimeOnly("[async] `await` must be enclosed in an `async` block")
def await[T](awaitable: F[T]): T =
??? // No implementation here, as calls to this are translated by the macro.

/**
* Run the block of code `body` asynchronously. `body` may contain calls to `await` when the results of
* a `F` are needed; this is translated into non-blocking code.
*/
def async[T](body: => T): F[T] = macro AsyncAwaitDsl.asyncImpl[F, T]

}

object AsyncAwaitDsl {

type AwaitCallback[F[_]] = Either[Throwable, F[AnyRef]] => Unit

// Outcome of an await block. Either a failed algebraic computation,
// or a successful value accompanied by a "summary" computation.
//
// Allows to short-circuit the async/await state machine when relevant
// (think OptionT.none) and track algebraic information that may otherwise
// get lost during Dispatcher#unsafeRun calls (WriterT/IorT logs).
type AwaitOutcome[F[_]] = Either[F[AnyRef], (F[Unit], AnyRef)]
djspiewak marked this conversation as resolved.
Show resolved Hide resolved

def asyncImpl[F[_], T](
c: blackbox.Context
)(body: c.Expr[T]): c.Expr[F[T]] = {
import c.universe._
if (!c.compilerSettings.contains("-Xasync")) {
c.abort(
c.macroApplication.pos,
"The async requires the compiler option -Xasync (supported only by Scala 2.12.12+ / 2.13.3+)"
)
} else
try {
val awaitSym = typeOf[AsyncAwaitDsl[Any]].decl(TermName("await"))
def mark(t: DefDef): Tree = {
c.internal
.asInstanceOf[{
def markForAsyncTransform(
owner: Symbol,
method: DefDef,
awaitSymbol: Symbol,
config: Map[String, AnyRef]
): DefDef
}]
.markForAsyncTransform(
c.internal.enclosingOwner,
t,
awaitSym,
Map.empty
)
}
val name = TypeName("stateMachine$async")
// format: off
val tree = q"""
final class $name(dispatcher: _root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext], callback: _root_.cats.effect.std.AsyncAwaitDsl.AwaitCallback[${c.prefix}._AsyncContext]) extends _root_.cats.effect.std.AsyncAwaitStateMachine(dispatcher, callback) {
${mark(q"""override def apply(tr$$async: _root_.cats.effect.std.AsyncAwaitDsl.AwaitOutcome[${c.prefix}._AsyncContext]): _root_.scala.Unit = ${body}""")}
}
${c.prefix}._AsyncInstance.flatten {
_root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext].use { dispatcher =>
${c.prefix}._AsyncInstance.async_[${c.prefix}._AsyncContext[AnyRef]](cb => new $name(dispatcher, cb).start())
}
}.asInstanceOf[${c.macroApplication.tpe}]
"""
// format: on
c.Expr(tree)
} catch {
case e: ReflectiveOperationException =>
c.abort(
c.macroApplication.pos,
"-Xasync is provided as a Scala compiler option, but the async macro is unable to call c.internal.markForAsyncTransform. " + e
.getClass
.getName + " " + e.getMessage
)
}
}

}

abstract class AsyncAwaitStateMachine[F[_]](
dispatcher: Dispatcher[F],
callback: AsyncAwaitDsl.AwaitCallback[F]
)(implicit F: Async[F])
extends Function1[AsyncAwaitDsl.AwaitOutcome[F], Unit] {

// FSM translated method
//def apply(v1: AsyncAwaitDsl.AwaitOutcome[F]): Unit = ???

// Resorting to mutation to track algebraic product effects (like WriterT),
// since the information they carry would otherwise get lost on every dispatch.
private[this] var summary: F[Unit] = F.unit

private[this] var state$async: Int = 0

/**
* Retrieve the current value of the state variable
*/
protected def state: Int = state$async

/**
* Assign `i` to the state variable
*/
protected def state_=(s: Int): Unit = state$async = s

protected def completeFailure(t: Throwable): Unit =
callback(Left(t))

protected def completeSuccess(value: AnyRef): Unit =
callback(Right(F.as(summary, value)))

protected def onComplete(f: F[AnyRef]): Unit = {
dispatcher.unsafeRunAndForget {
// Resorting to mutation to extract the "happy path" value from the monadic context,
// as inspecting the Succeeded outcome using dispatcher is risky on algebraic sums,
// such as OptionT, EitherT, ...
var awaitedValue: Option[AnyRef] = None
F.uncancelable { poll =>
poll(summary *> f)
.flatTap(r => F.delay { awaitedValue = Some(r) })
.start
.flatMap(fiber => poll(fiber.join).onCancel(fiber.cancel))
}.flatMap {
case Canceled() => F.delay(this(Left(F.canceled.asInstanceOf[F[AnyRef]])))
case Errored(e) => F.delay(this(Left(F.raiseError(e))))
case Succeeded(awaitOutcome) =>
awaitedValue match {
case Some(v) => F.delay(this(Right(awaitOutcome.void -> v)))
case None => F.delay(this(Left(awaitOutcome)))
}
}
}
}

protected def getCompleted(f: F[AnyRef]): AsyncAwaitDsl.AwaitOutcome[F] = {
val _ = f
null
}

protected def tryGet(awaitOutcome: AsyncAwaitDsl.AwaitOutcome[F]): AnyRef =
awaitOutcome match {
case Right((newSummary, value)) =>
summary = newSummary
value
case Left(monadicStop) =>
callback(Right(monadicStop))
this // sentinel value to indicate the dispatch loop should exit.
}

def start(): Unit = {
// Required to kickstart the async state machine.
// `def apply` does not consult its argument when `state == 0`.
apply(null)
}

}
Loading