Skip to content

Commit

Permalink
Refactor detection of ill nested await calls
Browse files Browse the repository at this point in the history
We can defer the check until the current location of the "fallback"
check without sacrificing the specific error messages.
  • Loading branch information
retronym committed Jun 18, 2020
1 parent ae91a3c commit 89f48d2
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 30 deletions.
52 changes: 31 additions & 21 deletions src/compiler/scala/tools/nsc/transform/async/AnfTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,26 @@ private[async] trait AnfTransform extends TransformUtils {
curTree = tree
val treeContainsAwait = containsAwait(tree)
tree match {
case _: ClassDef | _: ModuleDef | _: Function | _: DefDef =>
case _: ClassDef | _: ModuleDef | _: Function | _: DefDef =>
tree
case _ if !treeContainsAwait =>
case _ if !treeContainsAwait =>
tree
case Apply(sel @ Select(fun, _), arg :: Nil) if isBooleanAnd(sel.symbol) && containsAwait(arg) =>
case Apply(sel@Select(fun, _), arg :: Nil) if isBooleanAnd(sel.symbol) && containsAwait(arg) =>
transform(treeCopy.If(tree, fun, arg, literalBool(false)))
case Apply(sel @ Select(fun, _), arg :: Nil) if isBooleanOr(sel.symbol) && containsAwait(arg) =>
case Apply(sel@Select(fun, _), arg :: Nil) if isBooleanOr(sel.symbol) && containsAwait(arg) =>
transform(treeCopy.If(tree, fun, literalBool(true), arg))
case Apply(fun, args) =>
val lastAwaitArgIndex: Int = args.lastIndexWhere(containsAwait)
val simpleFun = transform(fun)
var i = 0
val argExprss = map2(args, fun.symbol.paramss.head) { (arg: Tree, param: Symbol) =>
case Apply(fun, args) =>
val lastAwaitArgIndex: RunId = args.lastIndexWhere(containsAwait)
val simpleFun = transform(fun)
var i = 0
val argExprss = map2(args, fun.symbol.paramss.head) { (arg: Tree, param: Symbol) =>
transform(arg) match {
case expr1 =>
val argName = param.name.toTermName
val argName = param.name.toTermName
// No need to extract the argument into a val if is non-side-effecting or if we are beyond the final
// argument containing an `await` calls.
val elideVal = treeInfo.isExprSafeToInline(expr1) || lastAwaitArgIndex < 0 || i > lastAwaitArgIndex || !treeContainsAwait
val result = if (elideVal) {
val result = if (elideVal) {
localTyper.typed(expr1, arg.tpe) // Adapt () to BoxedUnit
} else {
if (isUnitType(expr1.tpe)) {
Expand All @@ -86,11 +86,11 @@ private[async] trait AnfTransform extends TransformUtils {
result
}
}
val simpleApply = treeCopy.Apply(tree, simpleFun, argExprss)
val simpleApply = treeCopy.Apply(tree, simpleFun, argExprss)
simpleApply.attachments.remove[ContainsAwait.type]
if (isAwait(fun)) {
val valDef = defineVal(transformState.name.await(), treeCopy.Apply(tree, fun, argExprss), tree.pos)
val ref = gen.mkAttributedStableRef(valDef.symbol).setType(tree.tpe)
val ref = gen.mkAttributedStableRef(valDef.symbol).setType(tree.tpe)
currentStats += valDef
atPos(tree.pos)(ref)
} else {
Expand Down Expand Up @@ -119,7 +119,7 @@ private[async] trait AnfTransform extends TransformUtils {
eliminateMatchEndLabelParameter(tree.pos, ts).foreach(t => flattenBlock(t)(currentStats += _)),
onTail = (ts: List[Tree]) =>
ts.foreach(t => flattenBlock(t)(currentStats += _))
)
)

// However, we let `onTail` add the expr to `currentStats` (that was more efficient than using `ts.dropRight(1).foreach(addToStats)`)
// Compensate by removing it from the buffer and returning the expr.
Expand All @@ -129,10 +129,10 @@ private[async] trait AnfTransform extends TransformUtils {
case ld: LabelDef if ld.tpe.typeSymbol == definitions.BoxedUnitClass =>
currentStats += ld
literalBoxedUnit
case ld: LabelDef if ld.tpe.typeSymbol == definitions.UnitClass =>
case ld: LabelDef if ld.tpe.typeSymbol == definitions.UnitClass =>
currentStats += ld
literalUnit
case expr => expr
case expr => expr
}

case ValDef(mods, name, tpt, rhs) => atOwner(tree.symbol) {
Expand Down Expand Up @@ -161,7 +161,7 @@ private[async] trait AnfTransform extends TransformUtils {
case If(cond, thenp, elsep) =>
val needsResultVar = (containsAwait(thenp) || containsAwait(elsep))
transformMatchOrIf(tree, needsResultVar, transformState.name.ifRes) { varSym =>
val condExpr = transform(cond)
val condExpr = transform(cond)
val thenBlock = transformNewControlFlowBlock(thenp)
val elseBlock = transformNewControlFlowBlock(elsep)
treeCopy.If(tree, condExpr, pushAssignmentIntoExpr(varSym, thenBlock), pushAssignmentIntoExpr(varSym, elseBlock))
Expand All @@ -170,19 +170,29 @@ private[async] trait AnfTransform extends TransformUtils {
case Match(scrut, cases) =>
val needResultVar = cases.exists(containsAwait)
transformMatchOrIf(tree, needResultVar, transformState.name.matchRes) { varSym =>
val scrutExpr = transform(scrut)
val scrutExpr = transform(scrut)
val casesWithAssign = cases map {
case cd@CaseDef(pat, guard, body) =>
assignUnitType(treeCopy.CaseDef(cd, pat, transformNewControlFlowBlock(guard), pushAssignmentIntoExpr(varSym, transformNewControlFlowBlock(body))))
}
treeCopy.Match(tree, scrutExpr, casesWithAssign)
}

case ld @ LabelDef(name, params, rhs) =>
case ld@LabelDef(name, params, rhs) =>
treeCopy.LabelDef(tree, name, params, transformNewControlFlowBlock(rhs))
case t @ Typed(expr, tpt) =>
case t@Typed(expr, tpt) =>
transform(expr).setType(t.tpe)
case _ =>
case Try(body, catches, finalizer) =>
// This gets reported in ExprBuilder as an unsupported use of await. We still need to
// have _some_ non-default transform here make all cases in test/async/neg/ill-nested-await.check pass.
//
// TODO Create a result variable for try expression.
// Model exceptional control flow in ExprBuilder and remove this restriction.
treeCopy.Try(tree,
transformNewControlFlowBlock(body),
catches.mapConserve(cd => transformNewControlFlowBlock(cd).asInstanceOf[CaseDef]),
transformNewControlFlowBlock(finalizer))
case _ =>
super.transform(tree)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import scala.collection.mutable
import scala.tools.nsc.transform.{Transform, TypingTransformers}
import scala.reflect.internal.util.SourceFile

abstract class AsyncPhase extends Transform with TypingTransformers with AnfTransform with AsyncAnalysis with Lifter with LiveVariables {
abstract class AsyncPhase extends Transform with TypingTransformers with AnfTransform with Lifter with LiveVariables {
self =>
import global._

Expand Down Expand Up @@ -146,7 +146,6 @@ abstract class AsyncPhase extends Transform with TypingTransformers with AnfTran
// We mark whether each sub-tree of `asyncBody` that do or do not contain an await in thus pre-processing pass.
// The ANF transform can then efficiently query this to selectively transform the tree.
markContainsAwait(asyncBody)
reportUnsupportedAwaits(asyncBody)

// Transform to A-normal form:
// - no await calls in qualifiers or arguments,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ package scala.tools.nsc.transform.async
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

trait ExprBuilder extends TransformUtils {

trait ExprBuilder extends TransformUtils with AsyncAnalysis {
import global._

private def stateAssigner = currentTransformState.stateAssigner
Expand Down Expand Up @@ -347,11 +346,7 @@ trait ExprBuilder extends TransformUtils {
}

private def checkForUnsupportedAwait(tree: Tree) = if (containsAwait(tree)) {
tree.foreach {
case tree: RefTree if isAwait(tree) =>
global.reporter.error(tree.pos, "await must not be used in this position")
case _ =>
}
reportUnsupportedAwaits(tree)
}

/** Copy these states into the current block builder's async stats updating the open state builder's
Expand Down

0 comments on commit 89f48d2

Please sign in to comment.