Skip to content

Commit

Permalink
Replace CheckCPSMethodTraverser with additional parameter on transfor…
Browse files Browse the repository at this point in the history
…mer methods

Other fixes:

- remove CPSUtils.allCPSMethods
- add clarifying comment about adding a plus marker to a return expression
  • Loading branch information
phaller committed Nov 2, 2012
1 parent 2c00346 commit 4c5aa9b
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 68 deletions.
1 change: 0 additions & 1 deletion src/compiler/scala/tools/nsc/typechecker/Typers.scala
Expand Up @@ -3182,7 +3182,6 @@ trait Typers extends Modes {
else {
context.enclMethod.returnsSeen = true
val expr1: Tree = typed(expr, EXPRmode | BYVALmode | RETmode, restpt.tpe)

// Warn about returning a value if no value can be returned.
if (restpt.tpe.typeSymbol == UnitClass) {
// The typing in expr1 says expr is Unit (it has already been coerced if
Expand Down
Expand Up @@ -225,6 +225,9 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
} else if (retMode && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) {
// add a marker annotation that will make tree.tpe behave as pt, subtyping wise
// tree will look like having no annotation

// note that we are only adding a plus marker if the method's result type is a CPS type
// (annotsExpected.nonEmpty == cpsParamAnnotation(pt).nonEmpty)
val res = tree modifyType (_ withAnnotations List(newPlusMarker()))
vprintln("adapted annotations (return) of " + tree + " to " + res.tpe)
res
Expand Down
Expand Up @@ -29,8 +29,6 @@ trait CPSUtils {
val shiftUnit0 = newTermName("shiftUnit0")
val shiftUnit = newTermName("shiftUnit")
val shiftUnitR = newTermName("shiftUnitR")
val reset = newTermName("reset")
val reset0 = newTermName("reset0")
}

lazy val MarkerCPSSym = definitions.getClass("scala.util.continuations.cpsSym")
Expand All @@ -49,15 +47,10 @@ trait CPSUtils {
lazy val MethShiftR = definitions.getMember(ModCPS, cpsNames.shiftR)
lazy val MethReify = definitions.getMember(ModCPS, cpsNames.reify)
lazy val MethReifyR = definitions.getMember(ModCPS, cpsNames.reifyR)
lazy val MethReset = definitions.getMember(ModCPS, cpsNames.reset)
lazy val MethReset0 = definitions.getMember(ModCPS, cpsNames.reset0)

lazy val allCPSAnnotations = List(MarkerCPSSym, MarkerCPSTypes, MarkerCPSSynth,
MarkerCPSAdaptPlus, MarkerCPSAdaptMinus)

lazy val allCPSMethods = List(MethShiftUnit, MethShiftUnit0, MethShiftUnitR, MethShift, MethShiftR,
MethReify, MethReifyR, MethReset, MethReset0)

def debuglog(s: => String): Unit = {
//Console.err.println(s)
}
Expand Down
Expand Up @@ -32,39 +32,15 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
implicit val _unit = unit // allow code in CPSUtils.scala to report errors
var cpsAllowed: Boolean = false // detect cps code in places we do not handle (yet)

/* Does not attempt to remove tail returns.
* Only checks whether the method contains returns as well as calls to CPS methods.
*/
class CheckCPSMethodsTraverser extends Traverser {
var cpsMethodsSeen = false
var returnsSeen: Option[Tree] = None
override def traverse(tree: Tree): Unit = tree match {
case Ident(_) | Select(_, _) =>
if (tree.hasSymbol && (allCPSMethods contains tree.symbol))
cpsMethodsSeen = true
super.traverse(tree)
case Return(_) =>
returnsSeen = Some(tree)
case _ =>
super.traverse(tree)
}
}

/* Also checks whether the method calls a CPS method in which case an error is produced
*/
class RemoveTailReturnsTransformer extends Transformer {
var cpsMethodsSeen = false
var returnsSeen: Option[Tree] = None
object RemoveTailReturnsTransformer extends Transformer {
override def transform(tree: Tree): Tree = tree match {
case Block(stms, r @ Return(expr)) =>
returnsSeen = Some(r)
treeCopy.Block(tree, stms, expr)

case Block(stms, expr) =>
treeCopy.Block(tree, stms, transform(expr))

case If(cond, r1 @ Return(thenExpr), r2 @ Return(elseExpr)) =>
returnsSeen = Some(r1)
treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr))

case If(cond, thenExpr, elseExpr) =>
Expand All @@ -77,7 +53,6 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
transform(finalizer))

case CaseDef(pat, guard, r @ Return(expr)) =>
returnsSeen = Some(r)
treeCopy.CaseDef(tree, pat, guard, expr)

case CaseDef(pat, guard, body) =>
Expand All @@ -87,11 +62,6 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
unit.error(tree.pos, "return expressions in CPS code must be in tail position")
tree

case Ident(_) | Select(_, _) =>
if (tree.hasSymbol && (allCPSMethods contains tree.symbol))
cpsMethodsSeen = true
super.transform(tree)

case _ =>
super.transform(tree)
}
Expand All @@ -101,12 +71,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
// support body with single return expression
body match {
case Return(expr) => expr
case _ =>
val tr = new RemoveTailReturnsTransformer
val res = tr.transform(body)
if (tr.returnsSeen.nonEmpty && tr.cpsMethodsSeen)
unit.error(tr.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method")
res
case _ => RemoveTailReturnsTransformer.transform(body)
}
}

Expand All @@ -130,14 +95,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
atOwner(dd.symbol) {
val rhs =
if (cpsParamTypes(tpt.tpe).nonEmpty) removeTailReturns(rhs0)
else {
val checker = new CheckCPSMethodsTraverser
checker.traverse(rhs0)
if (checker.returnsSeen.nonEmpty && checker.cpsMethodsSeen)
unit.error(checker.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method")
rhs0
}
val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))
else rhs0
val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))(getExternalAnswerTypeAnn(tpt.tpe).isDefined)

debuglog("result "+rhs1)
debuglog("result is of type "+rhs1.tpe)
Expand All @@ -162,6 +121,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with

val ext = getExternalAnswerTypeAnn(body.tpe)
val pureBody = getAnswerTypeAnn(body.tpe).isEmpty
implicit val isParentImpure = ext.isDefined

def transformPureMatch(tree: Tree, selector: Tree, cases: List[CaseDef]) = {
val caseVals = cases map { case cd @ CaseDef(pat, guard, body) =>
Expand Down Expand Up @@ -241,16 +201,16 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
}


def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): Tree = {
transTailValue(tree, cpsA, cpsR) match {
def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean = false): Tree = {
transTailValue(tree, cpsA, cpsR)(cpsR.isDefined || isAnyParentImpure) match {
case (Nil, b) => b
case (a, b) =>
treeCopy.Block(tree, a,b)
}
}


def transArgList(fun: Tree, args: List[Tree], cpsA: CPSInfo): (List[List[Tree]], List[Tree], CPSInfo) = {
def transArgList(fun: Tree, args: List[Tree], cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[List[Tree]], List[Tree], CPSInfo) = {
val formals = fun.tpe.paramTypes
val overshoot = args.length - formals.length

Expand All @@ -259,7 +219,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
val (stm,expr) = (for ((a,tp) <- args.zip(formals ::: List.fill(overshoot)(NoType))) yield {
tp match {
case TypeRef(_, ByNameParamClass, List(elemtp)) =>
(Nil, transExpr(a, None, getAnswerTypeAnn(elemtp)))
// note that we're not passing just isAnyParentImpure
(Nil, transExpr(a, None, getAnswerTypeAnn(elemtp))(getAnswerTypeAnn(elemtp).isDefined || isAnyParentImpure))
case _ =>
val (valStm, valExpr, valSpc) = transInlineValue(a, spc)
spc = valSpc
Expand All @@ -271,15 +232,16 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
}


def transValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree, CPSInfo) = {
// precondition: cpsR.isDefined "implies" isAnyParentImpure
def transValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = {
// return value: (stms, expr, spc), where spc is CPSInfo after stms but *before* expr
implicit val pos = tree.pos
tree match {
case Block(stms, expr) =>
val (cpsA2, cpsR2) = (cpsA, linearize(cpsA, getAnswerTypeAnn(tree.tpe))) // tbd
// val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe))

val (a, b) = transBlock(stms, expr, cpsA2, cpsR2)
val (a, b) = transBlock(stms, expr, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)
val tree1 = (treeCopy.Block(tree, a, b)) // no updateSynthFlag here!!!

(Nil, tree1, cpsA)
Expand All @@ -293,8 +255,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
val (cpsA2, cpsR2) = if (hasSynthMarker(tree.tpe))
(spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) else
(None, getAnswerTypeAnn(tree.tpe)) // if no cps in condition, branches must conform to tree.tpe directly
val thenVal = transExpr(thenp, cpsA2, cpsR2)
val elseVal = transExpr(elsep, cpsA2, cpsR2)
val thenVal = transExpr(thenp, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)
val elseVal = transExpr(elsep, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)

// check that then and else parts agree (not necessary any more, but left as sanity check)
if (cpsR.isDefined) {
Expand All @@ -314,7 +276,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
else (None, getAnswerTypeAnn(tree.tpe))

val caseVals = cases map { case cd @ CaseDef(pat, guard, body) =>
val bodyVal = transExpr(body, cpsA2, cpsR2)
val bodyVal = transExpr(body, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)
treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal)
}

Expand All @@ -332,7 +294,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
// currentOwner.newMethod(name, tree.pos, Flags.SYNTHETIC) setInfo ldef.symbol.info
val sym = ldef.symbol resetFlag Flags.LABEL
val rhs1 = rhs //new TreeSymSubstituter(List(ldef.symbol), List(sym)).transform(rhs)
val rhsVal = transExpr(rhs1, None, getAnswerTypeAnn(tree.tpe)) changeOwner (currentOwner -> sym)
val rhsVal = transExpr(rhs1, None, getAnswerTypeAnn(tree.tpe))(getAnswerTypeAnn(tree.tpe).isDefined || isAnyParentImpure) changeOwner (currentOwner -> sym)

val stm1 = localTyper.typed(DefDef(sym, rhsVal))
// since virtpatmat does not rely on fall-through, don't call the labels it emits
Expand Down Expand Up @@ -371,6 +333,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
(stms, updateSynthFlag(treeCopy.Assign(tree, transform(lhs), expr)), spc)

case Return(expr0) =>
if (isAnyParentImpure)
unit.error(tree.pos, "return expression not allowed, since method calls CPS method")
val (stms, expr, spc) = transInlineValue(expr0, cpsA)
(stms, updateSynthFlag(treeCopy.Return(tree, expr)), spc)

Expand Down Expand Up @@ -408,7 +372,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
}
}

def transTailValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree) = {
// precondition: cpsR.isDefined "implies" isAnyParentImpure
def transTailValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = {

val (stms, expr, spc) = transValue(tree, cpsA, cpsR)

Expand Down Expand Up @@ -485,7 +450,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
(stms, expr)
}

def transInlineValue(tree: Tree, cpsA: CPSInfo): (List[Tree], Tree, CPSInfo) = {
def transInlineValue(tree: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = {

val (stms, expr, spc) = transValue(tree, cpsA, None) // never required to be cps

Expand Down Expand Up @@ -513,7 +478,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with



def transInlineStm(stm: Tree, cpsA: CPSInfo): (List[Tree], CPSInfo) = {
def transInlineStm(stm: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], CPSInfo) = {
stm match {

// TODO: what about DefDefs?
Expand Down Expand Up @@ -543,7 +508,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
}
}

def transBlock(stms: List[Tree], expr: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree) = {
// precondition: cpsR.isDefined "implies" isAnyParentImpure
def transBlock(stms: List[Tree], expr: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = {
def rec(currStats: List[Tree], currAns: CPSInfo, accum: List[Tree]): (List[Tree], Tree) =
currStats match {
case Nil =>
Expand Down
2 changes: 1 addition & 1 deletion test/files/continuations-neg/t5314-return-reset.check
@@ -1,4 +1,4 @@
t5314-return-reset.scala:14: error: return expressions not allowed, since method calls CPS method
t5314-return-reset.scala:14: error: return expression not allowed, since method calls CPS method
if (rnd.nextInt(100) > 50) return 5 // not allowed, since method is calling `reset`
^
one error found

0 comments on commit 4c5aa9b

Please sign in to comment.