Permalink
Browse files

Replace CheckCPSMethodTraverser with additional parameter on transfor…

…mer methods

Other fixes:

- remove CPSUtils.allCPSMethods
- add clarifying comment about adding a plus marker to a return expression
  • Loading branch information...
phaller committed Aug 8, 2012
1 parent 2c00346 commit 4c5aa9badf1e67e83cc5ea393611dad4a2edb60e
@@ -3182,7 +3182,6 @@ trait Typers extends Modes {
else {
context.enclMethod.returnsSeen = true
val expr1: Tree = typed(expr, EXPRmode | BYVALmode | RETmode, restpt.tpe)
-
// Warn about returning a value if no value can be returned.
if (restpt.tpe.typeSymbol == UnitClass) {
// The typing in expr1 says expr is Unit (it has already been coerced if
@@ -225,6 +225,9 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
} else if (retMode && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) {
// add a marker annotation that will make tree.tpe behave as pt, subtyping wise
// tree will look like having no annotation
+
+ // note that we are only adding a plus marker if the method's result type is a CPS type
+ // (annotsExpected.nonEmpty == cpsParamAnnotation(pt).nonEmpty)
val res = tree modifyType (_ withAnnotations List(newPlusMarker()))
vprintln("adapted annotations (return) of " + tree + " to " + res.tpe)
res
@@ -29,8 +29,6 @@ trait CPSUtils {
val shiftUnit0 = newTermName("shiftUnit0")
val shiftUnit = newTermName("shiftUnit")
val shiftUnitR = newTermName("shiftUnitR")
- val reset = newTermName("reset")
- val reset0 = newTermName("reset0")
}
lazy val MarkerCPSSym = definitions.getClass("scala.util.continuations.cpsSym")
@@ -49,15 +47,10 @@ trait CPSUtils {
lazy val MethShiftR = definitions.getMember(ModCPS, cpsNames.shiftR)
lazy val MethReify = definitions.getMember(ModCPS, cpsNames.reify)
lazy val MethReifyR = definitions.getMember(ModCPS, cpsNames.reifyR)
- lazy val MethReset = definitions.getMember(ModCPS, cpsNames.reset)
- lazy val MethReset0 = definitions.getMember(ModCPS, cpsNames.reset0)
lazy val allCPSAnnotations = List(MarkerCPSSym, MarkerCPSTypes, MarkerCPSSynth,
MarkerCPSAdaptPlus, MarkerCPSAdaptMinus)
- lazy val allCPSMethods = List(MethShiftUnit, MethShiftUnit0, MethShiftUnitR, MethShift, MethShiftR,
- MethReify, MethReifyR, MethReset, MethReset0)
-
def debuglog(s: => String): Unit = {
//Console.err.println(s)
}
@@ -32,39 +32,15 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
implicit val _unit = unit // allow code in CPSUtils.scala to report errors
var cpsAllowed: Boolean = false // detect cps code in places we do not handle (yet)
- /* Does not attempt to remove tail returns.
- * Only checks whether the method contains returns as well as calls to CPS methods.
- */
- class CheckCPSMethodsTraverser extends Traverser {
- var cpsMethodsSeen = false
- var returnsSeen: Option[Tree] = None
- override def traverse(tree: Tree): Unit = tree match {
- case Ident(_) | Select(_, _) =>
- if (tree.hasSymbol && (allCPSMethods contains tree.symbol))
- cpsMethodsSeen = true
- super.traverse(tree)
- case Return(_) =>
- returnsSeen = Some(tree)
- case _ =>
- super.traverse(tree)
- }
- }
-
- /* Also checks whether the method calls a CPS method in which case an error is produced
- */
- class RemoveTailReturnsTransformer extends Transformer {
- var cpsMethodsSeen = false
- var returnsSeen: Option[Tree] = None
+ object RemoveTailReturnsTransformer extends Transformer {
override def transform(tree: Tree): Tree = tree match {
case Block(stms, r @ Return(expr)) =>
- returnsSeen = Some(r)
treeCopy.Block(tree, stms, expr)
case Block(stms, expr) =>
treeCopy.Block(tree, stms, transform(expr))
case If(cond, r1 @ Return(thenExpr), r2 @ Return(elseExpr)) =>
- returnsSeen = Some(r1)
treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr))
case If(cond, thenExpr, elseExpr) =>
@@ -77,7 +53,6 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
transform(finalizer))
case CaseDef(pat, guard, r @ Return(expr)) =>
- returnsSeen = Some(r)
treeCopy.CaseDef(tree, pat, guard, expr)
case CaseDef(pat, guard, body) =>
@@ -87,11 +62,6 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
unit.error(tree.pos, "return expressions in CPS code must be in tail position")
tree
- case Ident(_) | Select(_, _) =>
- if (tree.hasSymbol && (allCPSMethods contains tree.symbol))
- cpsMethodsSeen = true
- super.transform(tree)
-
case _ =>
super.transform(tree)
}
@@ -101,12 +71,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
// support body with single return expression
body match {
case Return(expr) => expr
- case _ =>
- val tr = new RemoveTailReturnsTransformer
- val res = tr.transform(body)
- if (tr.returnsSeen.nonEmpty && tr.cpsMethodsSeen)
- unit.error(tr.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method")
- res
+ case _ => RemoveTailReturnsTransformer.transform(body)
}
}
@@ -130,14 +95,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
atOwner(dd.symbol) {
val rhs =
if (cpsParamTypes(tpt.tpe).nonEmpty) removeTailReturns(rhs0)
- else {
- val checker = new CheckCPSMethodsTraverser
- checker.traverse(rhs0)
- if (checker.returnsSeen.nonEmpty && checker.cpsMethodsSeen)
- unit.error(checker.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method")
- rhs0
- }
- val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))
+ else rhs0
+ val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))(getExternalAnswerTypeAnn(tpt.tpe).isDefined)
debuglog("result "+rhs1)
debuglog("result is of type "+rhs1.tpe)
@@ -162,6 +121,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
val ext = getExternalAnswerTypeAnn(body.tpe)
val pureBody = getAnswerTypeAnn(body.tpe).isEmpty
+ implicit val isParentImpure = ext.isDefined
def transformPureMatch(tree: Tree, selector: Tree, cases: List[CaseDef]) = {
val caseVals = cases map { case cd @ CaseDef(pat, guard, body) =>
@@ -241,16 +201,16 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
}
- def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): Tree = {
- transTailValue(tree, cpsA, cpsR) match {
+ def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean = false): Tree = {
+ transTailValue(tree, cpsA, cpsR)(cpsR.isDefined || isAnyParentImpure) match {
case (Nil, b) => b
case (a, b) =>
treeCopy.Block(tree, a,b)
}
}
- def transArgList(fun: Tree, args: List[Tree], cpsA: CPSInfo): (List[List[Tree]], List[Tree], CPSInfo) = {
+ def transArgList(fun: Tree, args: List[Tree], cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[List[Tree]], List[Tree], CPSInfo) = {
val formals = fun.tpe.paramTypes
val overshoot = args.length - formals.length
@@ -259,7 +219,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
val (stm,expr) = (for ((a,tp) <- args.zip(formals ::: List.fill(overshoot)(NoType))) yield {
tp match {
case TypeRef(_, ByNameParamClass, List(elemtp)) =>
- (Nil, transExpr(a, None, getAnswerTypeAnn(elemtp)))
+ // note that we're not passing just isAnyParentImpure
+ (Nil, transExpr(a, None, getAnswerTypeAnn(elemtp))(getAnswerTypeAnn(elemtp).isDefined || isAnyParentImpure))
case _ =>
val (valStm, valExpr, valSpc) = transInlineValue(a, spc)
spc = valSpc
@@ -271,15 +232,16 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
}
- def transValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree, CPSInfo) = {
+ // precondition: cpsR.isDefined "implies" isAnyParentImpure
+ def transValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = {
// return value: (stms, expr, spc), where spc is CPSInfo after stms but *before* expr
implicit val pos = tree.pos
tree match {
case Block(stms, expr) =>
val (cpsA2, cpsR2) = (cpsA, linearize(cpsA, getAnswerTypeAnn(tree.tpe))) // tbd
// val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe))
- val (a, b) = transBlock(stms, expr, cpsA2, cpsR2)
+ val (a, b) = transBlock(stms, expr, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)
val tree1 = (treeCopy.Block(tree, a, b)) // no updateSynthFlag here!!!
(Nil, tree1, cpsA)
@@ -293,8 +255,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
val (cpsA2, cpsR2) = if (hasSynthMarker(tree.tpe))
(spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) else
(None, getAnswerTypeAnn(tree.tpe)) // if no cps in condition, branches must conform to tree.tpe directly
- val thenVal = transExpr(thenp, cpsA2, cpsR2)
- val elseVal = transExpr(elsep, cpsA2, cpsR2)
+ val thenVal = transExpr(thenp, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)
+ val elseVal = transExpr(elsep, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)
// check that then and else parts agree (not necessary any more, but left as sanity check)
if (cpsR.isDefined) {
@@ -314,7 +276,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
else (None, getAnswerTypeAnn(tree.tpe))
val caseVals = cases map { case cd @ CaseDef(pat, guard, body) =>
- val bodyVal = transExpr(body, cpsA2, cpsR2)
+ val bodyVal = transExpr(body, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)
treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal)
}
@@ -332,7 +294,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
// currentOwner.newMethod(name, tree.pos, Flags.SYNTHETIC) setInfo ldef.symbol.info
val sym = ldef.symbol resetFlag Flags.LABEL
val rhs1 = rhs //new TreeSymSubstituter(List(ldef.symbol), List(sym)).transform(rhs)
- val rhsVal = transExpr(rhs1, None, getAnswerTypeAnn(tree.tpe)) changeOwner (currentOwner -> sym)
+ val rhsVal = transExpr(rhs1, None, getAnswerTypeAnn(tree.tpe))(getAnswerTypeAnn(tree.tpe).isDefined || isAnyParentImpure) changeOwner (currentOwner -> sym)
val stm1 = localTyper.typed(DefDef(sym, rhsVal))
// since virtpatmat does not rely on fall-through, don't call the labels it emits
@@ -371,6 +333,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
(stms, updateSynthFlag(treeCopy.Assign(tree, transform(lhs), expr)), spc)
case Return(expr0) =>
+ if (isAnyParentImpure)
+ unit.error(tree.pos, "return expression not allowed, since method calls CPS method")
val (stms, expr, spc) = transInlineValue(expr0, cpsA)
(stms, updateSynthFlag(treeCopy.Return(tree, expr)), spc)
@@ -408,7 +372,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
}
}
- def transTailValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree) = {
+ // precondition: cpsR.isDefined "implies" isAnyParentImpure
+ def transTailValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = {
val (stms, expr, spc) = transValue(tree, cpsA, cpsR)
@@ -485,7 +450,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
(stms, expr)
}
- def transInlineValue(tree: Tree, cpsA: CPSInfo): (List[Tree], Tree, CPSInfo) = {
+ def transInlineValue(tree: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = {
val (stms, expr, spc) = transValue(tree, cpsA, None) // never required to be cps
@@ -513,7 +478,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
- def transInlineStm(stm: Tree, cpsA: CPSInfo): (List[Tree], CPSInfo) = {
+ def transInlineStm(stm: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], CPSInfo) = {
stm match {
// TODO: what about DefDefs?
@@ -543,7 +508,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
}
}
- def transBlock(stms: List[Tree], expr: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree) = {
+ // precondition: cpsR.isDefined "implies" isAnyParentImpure
+ def transBlock(stms: List[Tree], expr: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = {
def rec(currStats: List[Tree], currAns: CPSInfo, accum: List[Tree]): (List[Tree], Tree) =
currStats match {
case Nil =>
@@ -1,4 +1,4 @@
-t5314-return-reset.scala:14: error: return expressions not allowed, since method calls CPS method
+t5314-return-reset.scala:14: error: return expression not allowed, since method calls CPS method
if (rnd.nextInt(100) > 50) return 5 // not allowed, since method is calling `reset`
^
one error found

0 comments on commit 4c5aa9b

Please sign in to comment.