Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 3 additions & 34 deletions src/main/scala/scala/async/Async.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package scala.async

import scala.language.experimental.macros
import scala.reflect.macros.Context
import scala.util.continuations.{cpsParam, reset}

object Async extends AsyncBase {

Expand Down Expand Up @@ -61,9 +60,7 @@ abstract class AsyncBase {
@deprecated("`await` must be enclosed in an `async` block", "0.1")
def await[T](awaitable: futureSystem.Fut[T]): T = ???

def awaitFallback[T, U](awaitable: futureSystem.Fut[T], p: futureSystem.Prom[U]): T @cpsParam[U, Unit] = ???

def fallbackEnabled = false
protected[async] def fallbackEnabled = false

def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
import c.universe._
Expand All @@ -72,7 +69,8 @@ abstract class AsyncBase {
val utils = TransformUtils[c.type](c)
import utils.{name, defn}

if (!analyzer.reportUnsupportedAwaits(body.tree) || !fallbackEnabled) {
analyzer.reportUnsupportedAwaits(body.tree)

// Transform to A-normal form:
// - no await calls in qualifiers or arguments,
// - if/match only used in statement position.
Expand Down Expand Up @@ -162,35 +160,6 @@ abstract class AsyncBase {

AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}")
code
} else {
// replace `await` invocations with `awaitFallback` invocations
val awaitReplacer = new Transformer {
override def transform(tree: Tree): Tree = tree match {
case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == defn.Async_await =>
val typeApp = treeCopy.TypeApply(fun, Ident(defn.Async_awaitFallback), List(TypeTree(futArgTpt.tpe), TypeTree(body.tree.tpe)))
treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident(name.result))
case _ =>
super.transform(tree)
}
}
val newBody = awaitReplacer.transform(body.tree)

val resetBody = reify {
reset { c.Expr(c.resetAllAttrs(newBody.duplicate)).splice }
}

val futureSystemOps = futureSystem.mkOps(c)
val code = {
val tree = Block(List(
ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree),
futureSystemOps.spawn(resetBody.tree)
), futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](Ident(name.result))).tree)
c.Expr[futureSystem.Fut[T]](tree)
}

AsyncUtils.vprintln(s"async CPS fallback transform expands to:\n ${code.tree}")
code
}
}

def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) {
Expand Down
31 changes: 0 additions & 31 deletions src/main/scala/scala/async/AsyncWithCPSFallback.scala

This file was deleted.

8 changes: 8 additions & 0 deletions src/main/scala/scala/async/FutureSystem.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ trait FutureSystem {

def spawn(tree: context.Tree): context.Tree =
future(context.Expr[Unit](tree))(execContext).tree

def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]]
}

def mkOps(c: Context): Ops { val context: c.type }
Expand Down Expand Up @@ -101,6 +103,10 @@ object ScalaConcurrentFutureSystem extends FutureSystem {
prom.splice.complete(value.splice)
context.literalUnit.splice
}

def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = reify {
future.splice.asInstanceOf[Fut[A]]
}
}
}

Expand Down Expand Up @@ -145,5 +151,7 @@ object IdentityFutureSystem extends FutureSystem {
prom.splice.a = value.splice.get
context.literalUnit.splice
}

def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ???
}
}
3 changes: 1 addition & 2 deletions src/main/scala/scala/async/TransformUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
tpe.member(newTermName(name)).ensuring(_ != NoSymbol)
}

val Async_await = asyncMember("await")
val Async_awaitFallback = asyncMember("awaitFallback")
val Async_await = asyncMember("await")
}

/** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
*/

package scala.async
package continuations

import scala.language.experimental.macros

import scala.reflect.macros.Context
import scala.util.continuations._

trait AsyncBaseWithCPSFallback extends AsyncBase {

/* Fall-back for `await` using CPS plugin.
*
* Note: This method is public, but is intended only for internal use.
*/
def awaitFallback[T](awaitable: futureSystem.Fut[T]): T @cps[futureSystem.Fut[Any]]

override protected[async] def fallbackEnabled = true

/* Implements `async { ... }` using the CPS plugin.
*/
protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
import c.universe._

def lookupMember(name: String) = {
val asyncTrait = c.mirror.staticClass("scala.async.continuations.AsyncBaseWithCPSFallback")
val tpe = asyncTrait.asType.toType
tpe.member(newTermName(name)).ensuring(_ != NoSymbol)
}

AsyncUtils.vprintln("AsyncBaseWithCPSFallback.cpsBasedAsyncImpl")

val utils = TransformUtils[c.type](c)
val futureSystemOps = futureSystem.mkOps(c)
val awaitSym = utils.defn.Async_await
val awaitFallbackSym = lookupMember("awaitFallback")

// replace `await` invocations with `awaitFallback` invocations
val awaitReplacer = new Transformer {
override def transform(tree: Tree): Tree = tree match {
case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == awaitSym =>
val typeApp = treeCopy.TypeApply(fun, Ident(awaitFallbackSym), List(TypeTree(futArgTpt.tpe)))
treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)))
case _ =>
super.transform(tree)
}
}
val bodyWithAwaitFallback = awaitReplacer.transform(body.tree)

/* generate an expression that looks like this:
reset {
val f = future { ... }
...
val x = awaitFallback(f)
...
future { expr }
}.asInstanceOf[Future[T]]
*/

val bodyWithFuture = {
val tree = bodyWithAwaitFallback match {
case Block(stmts, expr) => Block(stmts, futureSystemOps.spawn(expr))
case expr => futureSystemOps.spawn(expr)
}
c.Expr[futureSystem.Fut[Any]](c.resetAllAttrs(tree.duplicate))
}

val bodyWithReset: c.Expr[futureSystem.Fut[Any]] = reify {
reset { bodyWithFuture.splice }
}
val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset)

AsyncUtils.vprintln(s"CPS-based async transform expands to:\n${bodyWithCast.tree}")
bodyWithCast
}

override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
AsyncUtils.vprintln("AsyncBaseWithCPSFallback.asyncImpl")

val analyzer = AsyncAnalysis[c.type](c, this)

if (!analyzer.reportUnsupportedAwaits(body.tree))
super.asyncImpl[T](c)(body) // no unsupported awaits
else
cpsBasedAsyncImpl[T](c)(body) // fallback to CPS
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
*/

package scala.async
package continuations

import scala.language.experimental.macros

import scala.reflect.macros.Context
import scala.concurrent.Future

trait AsyncWithCPSFallback extends AsyncBaseWithCPSFallback with ScalaConcurrentCPSFallback

object AsyncWithCPSFallback extends AsyncWithCPSFallback {

def async[T](body: T) = macro asyncImpl[T]

override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
}
21 changes: 21 additions & 0 deletions src/main/scala/scala/async/continuations/CPSBasedAsync.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
*/

package scala.async
package continuations

import scala.language.experimental.macros

import scala.reflect.macros.Context
import scala.concurrent.Future

trait CPSBasedAsync extends CPSBasedAsyncBase with ScalaConcurrentCPSFallback

object CPSBasedAsync extends CPSBasedAsync {

def async[T](body: T) = macro asyncImpl[T]

override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)

}
21 changes: 21 additions & 0 deletions src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
*/

package scala.async
package continuations

import scala.language.experimental.macros

import scala.reflect.macros.Context
import scala.util.continuations._

/* Specializes `AsyncBaseWithCPSFallback` to always fall back to CPS, yielding a purely CPS-based
* implementation of async/await.
*/
trait CPSBasedAsyncBase extends AsyncBaseWithCPSFallback {

override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] =
super.cpsBasedAsyncImpl[T](c)(body)

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
*/

package scala.async
package continuations

import scala.util.continuations._
import scala.concurrent.{Future, Promise, ExecutionContext}

trait ScalaConcurrentCPSFallback {
self: AsyncBaseWithCPSFallback =>

import ExecutionContext.Implicits.global

lazy val futureSystem = ScalaConcurrentFutureSystem
type FS = ScalaConcurrentFutureSystem.type

/* Fall-back for `await` when it is called at an unsupported position.
*/
override def awaitFallback[T](awaitable: futureSystem.Fut[T]): T @cps[Future[Any]] =
shift {
(k: (T => Future[Any])) =>
val fr = Promise[Any]()
awaitable onComplete {
case tr => fr completeWith k(tr.get)
}
fr.future
}

}
49 changes: 49 additions & 0 deletions src/test/scala/scala/async/run/cps/CPSSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
*/

package scala.async
package run
package cps

import scala.concurrent.{Future, Promise, ExecutionContext, future, Await}
import scala.concurrent.duration._
import scala.async.continuations.CPSBasedAsync._
import scala.util.continuations._

import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test

@RunWith(classOf[JUnit4])
class CPSSpec {

import ExecutionContext.Implicits.global

def m1(y: Int): Future[Int] = async {
val f = future { y + 2 }
val f2 = future { y + 3 }
val x1 = await(f)
val x2 = await(f2)
x1 + x2
}

def m2(y: Int): Future[Int] = async {
val f = future { y + 2 }
val res = await(f)
if (y > 0) res + 2
else res - 2
}

@Test
def testCPSFallback() {
val fut1 = m1(10)
val res1 = Await.result(fut1, 2.seconds)
assert(res1 == 25, s"expected 25, got $res1")

val fut2 = m2(10)
val res2 = Await.result(fut2, 2.seconds)
assert(res2 == 14, s"expected 14, got $res2")
}

}