Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polymorphic async/await implementation #1924

Merged
merged 20 commits into from
May 31, 2021
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 24 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,16 @@ lazy val tests = crossProject(JSPlatform, JVMPlatform)
name := "cats-effect-tests",
libraryDependencies ++= Seq(
"org.typelevel" %%% "discipline-specs2" % DisciplineVersion % Test,
"org.typelevel" %%% "cats-kernel-laws" % CatsVersion % Test)
"org.typelevel" %%% "cats-kernel-laws" % CatsVersion % Test),
scalacOptions ++= List("-Xasync"),
Test / unmanagedSourceDirectories ++= {
if (!isDotty.value)
Seq(
(Compile / baseDirectory)
.value
.getParentFile() / "shared" / "src" / "test" / "scala-2")
else Seq()
}
)
.jvmSettings(
Test / fork := true,
Expand All @@ -337,7 +346,20 @@ lazy val std = crossProject(JSPlatform, JVMPlatform)
else
"org.specs2" %%% "specs2-scalacheck" % Specs2Version % Test
},
libraryDependencies += "org.scalacheck" %%% "scalacheck" % ScalaCheckVersion % Test
libraryDependencies += "org.scalacheck" %%% "scalacheck" % ScalaCheckVersion % Test,
libraryDependencies ++= {
if (!isDotty.value)
Seq("org.scala-lang" % "scala-reflect" % scalaVersion.value % "provided")
else Seq()
},
Compile / unmanagedSourceDirectories ++= {
if (!isDotty.value)
Seq(
(Compile / baseDirectory)
.value
.getParentFile() / "shared" / "src" / "main" / "scala-2")
else Seq()
}
)

/**
Expand Down
189 changes: 189 additions & 0 deletions std/shared/src/main/scala-2/AsyncAwait.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor thing, but can we rename this file in lower-case to reflect the fact that it contains multiple top-level members? Like asyncAwait.scala.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

* Copyright 2020-2021 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect.std

import scala.annotation.compileTimeOnly
import scala.reflect.macros.whitebox

import cats.effect.kernel.Async
import cats.effect.kernel.syntax.all._
import cats.syntax.all._
import cats.effect.kernel.Outcome.Canceled
import cats.effect.kernel.Outcome.Errored
import cats.effect.kernel.Outcome.Succeeded

class AsyncAwaitDsl[F[_]](implicit F: Async[F]) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add some scaladoc being like "uh, this is really really unstable, only works on Scala 2.x, and will likely change and/or be deprecated in the future"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love it.


/**
* Type member used by the macro expansion to recover what `F` is without typetags
*/
type _AsyncContext[A] = F[A]

/**
* Value member used by the macro expansion to recover the Async instance associated to the block.
*/
implicit val _AsyncInstance: Async[F] = F

/**
* Non-blocking await the on result of `awaitable`. This may only be used directly within an enclosing `async` block.
*
* Internally, this will register the remainder of the code in enclosing `async` block as a callback
* in the `onComplete` handler of `awaitable`, and will *not* block a thread.
*/
@compileTimeOnly("[async] `await` must be enclosed in an `async` block")
def await[T](awaitable: F[T]): T =
??? // No implementation here, as calls to this are translated to `onComplete` by the macro.

/**
* Run the block of code `body` asynchronously. `body` may contain calls to `await` when the results of
* a `Future` are needed; this is translated into non-blocking code.
*/
def async[T](body: => T): F[T] = macro AsyncAwaitDsl.asyncImpl[F, T]

}

object AsyncAwaitDsl {

type CallbackTarget[F[_]] = F[AnyRef]
type Callback[F[_]] = Either[Throwable, CallbackTarget[F]] => Unit
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this have to be public?

Copy link
Contributor Author

@Baccata Baccata May 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does, because it's used in macro expansion, so the client code needs to have visibility over those types.

These aliases have been really valuable during iteration as otherwise, any change to the types lead to having to amend the quasiquote expression in several places which is really, really annoying due to not getting compile errors until the macro expansion. I'd rather keep them for maintainability reasons but will obviously abide if you make an executive call to remove them.

In the meantime, I removed CallbackTarget as it's useless in this iteration and renamed Callback to AwaitCallback (8971556)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm cool with keeping them. Just wanted to sanity check.


// Outcome of an await block. Either a failed algebraic computation,
// or a successful value accompanied by a "summary" computation.
//
// Allows to short-circuit the async/await state machine when relevant
// (think OptionT.none) and track algebraic information that may otherwise
// get lost during Dispatcher#unsafeRun calls (WriterT/IorT logs).
type AwaitOutcome[F[_]] = Either[F[AnyRef], (F[Unit], AnyRef)]
djspiewak marked this conversation as resolved.
Show resolved Hide resolved

def asyncImpl[F[_], T](
c: whitebox.Context
)(body: c.Tree): c.Tree = {
import c.universe._
if (!c.compilerSettings.contains("-Xasync")) {
c.abort(
c.macroApplication.pos,
"The async requires the compiler option -Xasync (supported only by Scala 2.12.12+ / 2.13.3+)"
)
} else
try {
val awaitSym = typeOf[AsyncAwaitDsl[Any]].decl(TermName("await"))
def mark(t: DefDef): Tree = {
c.internal
.asInstanceOf[{
def markForAsyncTransform(
owner: Symbol,
method: DefDef,
awaitSymbol: Symbol,
config: Map[String, AnyRef]
): DefDef
}]
.markForAsyncTransform(
c.internal.enclosingOwner,
t,
awaitSym,
Map.empty
)
}
val name = TypeName("stateMachine$async")
// format: off
q"""
final class $name(dispatcher: _root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext], callback: _root_.cats.effect.std.AsyncAwaitDsl.Callback[${c.prefix}._AsyncContext]) extends _root_.cats.effect.std.AsyncAwaitStateMachine(dispatcher, callback) {
${mark(q"""override def apply(tr$$async: _root_.cats.effect.std.AsyncAwaitDsl.AwaitOutcome[${c.prefix}._AsyncContext]): _root_.scala.Unit = ${body}""")}
}
${c.prefix}._AsyncInstance.flatten {
_root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext].use { dispatcher =>
${c.prefix}._AsyncInstance.async_[${c.prefix}._AsyncContext[AnyRef]](cb => new $name(dispatcher, cb).start())
}
}.asInstanceOf[${c.macroApplication.tpe}]
"""
} catch {
case e: ReflectiveOperationException =>
c.abort(
c.macroApplication.pos,
"-Xasync is provided as a Scala compiler option, but the async macro is unable to call c.internal.markForAsyncTransform. " + e.getClass.getName + " " + e.getMessage
)
}
}

}

abstract class AsyncAwaitStateMachine[F[_]](
dispatcher: Dispatcher[F],
callback: AsyncAwaitDsl.Callback[F]
)(implicit F: Async[F]) extends Function1[AsyncAwaitDsl.AwaitOutcome[F], Unit] {

// FSM translated method
//def apply(v1: AsyncAwaitDsl.AwaitOutcome[F]): Unit = ???

// Resorting to mutation to track algebraic product effects (like WriterT),
// since the information they carry would otherwise get lost on every dispatch.
private[this] var summary : F[Unit] = F.unit

private[this] var state$async: Int = 0

/** Retrieve the current value of the state variable */
protected def state: Int = state$async

/** Assign `i` to the state variable */
protected def state_=(s: Int): Unit = state$async = s

protected def completeFailure(t: Throwable): Unit =
callback(Left(t))

protected def completeSuccess(value: AnyRef): Unit = {
callback(Right(F.as(summary, value)))
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
protected def completeSuccess(value: AnyRef): Unit = {
callback(Right(F.as(summary, value)))
}
protected def completeSuccess(value: AnyRef): Unit =
callback(Right(F.as(summary, value)))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


protected def onComplete(f: F[AnyRef]): Unit = {
dispatcher.unsafeRunAndForget {
// Resorting to mutation to extract the "happy path" value from the monadic context,
// as inspecting the Succeeded outcome using dispatcher is risky on algebraic sums,
// such as OptionT, EitherT, ...
var awaitedValue: Option[AnyRef] = None
(summary *> f).flatTap(r => F.delay{awaitedValue = Some(r)}).start.flatMap(_.join).flatMap {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It actually surprises me that scalafmt gives this line a pass.

Suggested change
(summary *> f).flatTap(r => F.delay{awaitedValue = Some(r)}).start.flatMap(_.join).flatMap {
(summary *> f).flatTap(r => F.delay { awaitedValue = Some(r) }).start.flatMap(_.join).flatMap {

Also, why is the .start.flatMap(_.join) necessary? Why not just guaranteeCase? (is this just working around #2013?) At the very least we should use uncancelable to ensure the Fiber doesn't leak.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It actually surprises me that scalafmt gives this line a pass.

I had disabled formatting before the quasiquotes and forgotten to re-enable it. Fixed in 8971556

Also, why is the .start.flatMap(_.join) necessary?

2 birds, one stone :

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the very least we should use uncancelable to ensure the Fiber doesn't leak.

Is this correct : 0220ebf ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Almost! Commented on the commit

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what I was pointing here. When #2013 is fixed, could be replaced by .forceR(F.void) I suppose

I'm not sure I understand why the forceR? Asynchronous non-terminating fibers don't take any resources and are eventually GC'd when no longer referenced, so they aren't really a problem per se.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Asynchronous non-terminating fibers don't take any resources

From the point of view of someone who doesn't know the innards of Dispatcher and how it relates to the effect's runtime, protecting against potential non-terminating behaviour seemed like a good reflex 😄. Also Dispatcher being polymorphic, can you guarantee that it'll be the case for all effect types ? (I reckon it'd be pretty bad if it wasn't the case, but still).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you guarantee that it'll be the case for all effect types ? (I reckon it'd be pretty bad if it wasn't the case, but still).

I think that any F for which non-termination in Dispatcher[F] is problematic would also be an F for which fibers are problematic. I think.

case Canceled() => F.delay(this(Left(F.canceled.asInstanceOf[F[AnyRef]])))
case Errored(e) => F.delay(this(Left(F.raiseError(e))))
case Succeeded(awaitOutcome) => awaitedValue match {
case Some(v) => F.delay(this(Right(awaitOutcome.void -> v)))
case None => F.delay(this(Left(awaitOutcome)))
}
}
}
}

protected def getCompleted(f: F[AnyRef]): AsyncAwaitDsl.AwaitOutcome[F] = {
val _ = f
null
}

protected def tryGet(awaitOutcome: AsyncAwaitDsl.AwaitOutcome[F]): AnyRef =
awaitOutcome match {
case Right((newSummary, value)) =>
summary = newSummary
value
case Left(monadicStop) =>
callback(Right(monadicStop))
this // sentinel value to indicate the dispatch loop should exit.
}

def start(): Unit = {
// Required to kickstart the async state machine.
// `def apply` does not consult its argument when `state == 0`.
apply(null)
}

}