Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

SI-5314 - CPS transform of return statement fails

Enable return expressions in CPS code if they are in tail position. Note that tail returns are
only removed in methods that do not call `shift` or `reset` (otherwise, an error is reported).

Addresses the issues pointed out in a previous pull request:
#720

- Addresses all issues mentioned here:
  #720 (comment)

- Move transformation methods to SelectiveANFTransform.scala:
  #720 (comment)

- Do not keep a list of tail returns.

Tests:
- continuations-neg/t5314-missing-result-type.scala
- continuations-neg/t5314-type-error.scala
- continuations-neg/t5314-npe.scala
- continuations-neg/t5314-return-reset.scala
- continuations-run/t5314.scala
- continuations-run/t5314-2.scala
- continuations-run/t5314-3.scala
  • Loading branch information...
commit 2c00346a9756190aef497cabbb6f77ecc25212c8 1 parent 8aeae62
@phaller phaller authored
Showing with 332 additions and 15 deletions.
  1. +5 −1 src/compiler/scala/tools/nsc/typechecker/Modes.scala
  2. +2 −1  src/compiler/scala/tools/nsc/typechecker/Typers.scala
  3. +33 −12 src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala
  4. +7 −0 src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala
  5. +88 −1 src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
  6. +4 −0 test/files/continuations-neg/t5314-missing-result-type.check
  7. +13 −0 test/files/continuations-neg/t5314-missing-result-type.scala
  8. +4 −0 test/files/continuations-neg/t5314-npe.check
  9. +3 −0  test/files/continuations-neg/t5314-npe.scala
  10. +4 −0 test/files/continuations-neg/t5314-return-reset.check
  11. +21 −0 test/files/continuations-neg/t5314-return-reset.scala
  12. +6 −0 test/files/continuations-neg/t5314-type-error.check
  13. +17 −0 test/files/continuations-neg/t5314-type-error.scala
  14. +5 −0 test/files/continuations-run/t5314-2.check
  15. +44 −0 test/files/continuations-run/t5314-2.scala
  16. +4 −0 test/files/continuations-run/t5314-3.check
  17. +27 −0 test/files/continuations-run/t5314-3.scala
  18. +4 −0 test/files/continuations-run/t5314.check
  19. +41 −0 test/files/continuations-run/t5314.scala
View
6 src/compiler/scala/tools/nsc/typechecker/Modes.scala
@@ -86,6 +86,10 @@ trait Modes {
*/
final val TYPEPATmode = 0x10000
+ /** RETmode is set when we are typing a return expression.
+ */
+ final val RETmode = 0x20000
+
final private val StickyModes = EXPRmode | PATTERNmode | TYPEmode | ALTmode
final def onlyStickyModes(mode: Int) =
@@ -130,4 +134,4 @@ trait Modes {
def modeString(mode: Int): String =
if (mode == 0) "NOmode"
else (modeNameMap filterKeys (bit => inAllModes(mode, bit))).values mkString " "
-}
+}
View
3  src/compiler/scala/tools/nsc/typechecker/Typers.scala
@@ -3181,7 +3181,8 @@ trait Typers extends Modes {
errorTree(tree, enclMethod.owner + " has return statement; needs result type")
else {
context.enclMethod.returnsSeen = true
- val expr1: Tree = typed(expr, EXPRmode | BYVALmode, restpt.tpe)
+ val expr1: Tree = typed(expr, EXPRmode | BYVALmode | RETmode, restpt.tpe)
+
// Warn about returning a value if no value can be returned.
if (restpt.tpe.typeSymbol == UnitClass) {
// The typing in expr1 says expr is Unit (it has already been coerced if
View
45 src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala
@@ -154,10 +154,8 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
if ((mode & global.analyzer.EXPRmode) != 0) {
if ((annots1 corresponds annots2)(_.atp <:< _.atp)) {
vprintln("already same, can't adapt further")
- return false
- }
-
- if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) {
+ false
+ } else if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) {
//println("can adapt annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt)
if (!hasPlusMarker(tree.tpe)) {
// val base = tree.tpe <:< removeAllCPSAnnotations(pt)
@@ -167,17 +165,26 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
// TBD: use same or not?
//if (same) {
vprintln("yes we can!! (unit)")
- return true
+ true
//}
- }
- } else if (!annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) {
- if (!hasMinusMarker(tree.tpe)) {
+ } else false
+ } else if (!hasPlusMarker(tree.tpe) && annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.RETmode) != 0)) {
+ vprintln("checking enclosing method's result type without annotations")
+ tree.tpe <:< pt.withoutAnnotations
+ } else if (!hasMinusMarker(tree.tpe) && !annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) {
+ val optCpsTypes: Option[(Type, Type)] = cpsParamTypes(tree.tpe)
+ val optExpectedCpsTypes: Option[(Type, Type)] = cpsParamTypes(pt)
+ if (optCpsTypes.isEmpty || optExpectedCpsTypes.isEmpty) {
vprintln("yes we can!! (byval)")
- return true
+ true
+ } else { // check cps param types
+ val cpsTpes = optCpsTypes.get
+ val cpsPts = optExpectedCpsTypes.get
+ // class cpsParam[-B,+C], therefore:
+ cpsPts._1 <:< cpsTpes._1 && cpsTpes._2 <:< cpsPts._2
}
- }
- }
- false
+ } else false
+ } else false
}
override def adaptAnnotations(tree: Tree, mode: Int, pt: Type): Tree = {
@@ -188,6 +195,7 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
val patMode = (mode & global.analyzer.PATTERNmode) != 0
val exprMode = (mode & global.analyzer.EXPRmode) != 0
val byValMode = (mode & global.analyzer.BYVALmode) != 0
+ val retMode = (mode & global.analyzer.RETmode) != 0
val annotsTree = cpsParamAnnotation(tree.tpe)
val annotsExpected = cpsParamAnnotation(pt)
@@ -214,6 +222,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
val res = tree modifyType addMinusMarker
vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe)
res
+ } else if (retMode && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) {
+ // add a marker annotation that will make tree.tpe behave as pt, subtyping wise
+ // tree will look like having no annotation
+ val res = tree modifyType (_ withAnnotations List(newPlusMarker()))
+ vprintln("adapted annotations (return) of " + tree + " to " + res.tpe)
+ res
} else tree
}
@@ -469,6 +483,13 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
}
tpe
+ case ret @ Return(expr) =>
+ // only change type if this return will (a) be removed (in tail position) or (b) cause
+ // an error (not in tail position)
+ if (expr.tpe != null && hasPlusMarker(expr.tpe))
+ ret setType expr.tpe
+ ret.tpe
+
case _ =>
tpe
}
View
7 src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala
@@ -29,6 +29,8 @@ trait CPSUtils {
val shiftUnit0 = newTermName("shiftUnit0")
val shiftUnit = newTermName("shiftUnit")
val shiftUnitR = newTermName("shiftUnitR")
+ val reset = newTermName("reset")
+ val reset0 = newTermName("reset0")
}
lazy val MarkerCPSSym = definitions.getClass("scala.util.continuations.cpsSym")
@@ -47,10 +49,15 @@ trait CPSUtils {
lazy val MethShiftR = definitions.getMember(ModCPS, cpsNames.shiftR)
lazy val MethReify = definitions.getMember(ModCPS, cpsNames.reify)
lazy val MethReifyR = definitions.getMember(ModCPS, cpsNames.reifyR)
+ lazy val MethReset = definitions.getMember(ModCPS, cpsNames.reset)
+ lazy val MethReset0 = definitions.getMember(ModCPS, cpsNames.reset0)
lazy val allCPSAnnotations = List(MarkerCPSSym, MarkerCPSTypes, MarkerCPSSynth,
MarkerCPSAdaptPlus, MarkerCPSAdaptMinus)
+ lazy val allCPSMethods = List(MethShiftUnit, MethShiftUnit0, MethShiftUnitR, MethShift, MethShiftR,
+ MethReify, MethReifyR, MethReset, MethReset0)
+
def debuglog(s: => String): Unit = {
//Console.err.println(s)
}
View
89 src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
@@ -32,6 +32,84 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
implicit val _unit = unit // allow code in CPSUtils.scala to report errors
var cpsAllowed: Boolean = false // detect cps code in places we do not handle (yet)
+ /* Does not attempt to remove tail returns.
+ * Only checks whether the method contains returns as well as calls to CPS methods.
+ */
+ class CheckCPSMethodsTraverser extends Traverser {
+ var cpsMethodsSeen = false
+ var returnsSeen: Option[Tree] = None
+ override def traverse(tree: Tree): Unit = tree match {
+ case Ident(_) | Select(_, _) =>
+ if (tree.hasSymbol && (allCPSMethods contains tree.symbol))
+ cpsMethodsSeen = true
+ super.traverse(tree)
+ case Return(_) =>
+ returnsSeen = Some(tree)
+ case _ =>
+ super.traverse(tree)
+ }
+ }
+
+ /* Also checks whether the method calls a CPS method in which case an error is produced
+ */
+ class RemoveTailReturnsTransformer extends Transformer {
+ var cpsMethodsSeen = false
+ var returnsSeen: Option[Tree] = None
+ override def transform(tree: Tree): Tree = tree match {
+ case Block(stms, r @ Return(expr)) =>
+ returnsSeen = Some(r)
+ treeCopy.Block(tree, stms, expr)
+
+ case Block(stms, expr) =>
+ treeCopy.Block(tree, stms, transform(expr))
+
+ case If(cond, r1 @ Return(thenExpr), r2 @ Return(elseExpr)) =>
+ returnsSeen = Some(r1)
+ treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr))
+
+ case If(cond, thenExpr, elseExpr) =>
+ treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr))
+
+ case Try(block, catches, finalizer) =>
+ treeCopy.Try(tree,
+ transform(block),
+ (catches map (t => transform(t))).asInstanceOf[List[CaseDef]],
+ transform(finalizer))
+
+ case CaseDef(pat, guard, r @ Return(expr)) =>
+ returnsSeen = Some(r)
+ treeCopy.CaseDef(tree, pat, guard, expr)
+
+ case CaseDef(pat, guard, body) =>
+ treeCopy.CaseDef(tree, pat, guard, transform(body))
+
+ case Return(_) =>
+ unit.error(tree.pos, "return expressions in CPS code must be in tail position")
+ tree
+
+ case Ident(_) | Select(_, _) =>
+ if (tree.hasSymbol && (allCPSMethods contains tree.symbol))
+ cpsMethodsSeen = true
+ super.transform(tree)
+
+ case _ =>
+ super.transform(tree)
+ }
+ }
+
+ def removeTailReturns(body: Tree): Tree = {
+ // support body with single return expression
+ body match {
+ case Return(expr) => expr
+ case _ =>
+ val tr = new RemoveTailReturnsTransformer
+ val res = tr.transform(body)
+ if (tr.returnsSeen.nonEmpty && tr.cpsMethodsSeen)
+ unit.error(tr.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method")
+ res
+ }
+ }
+
override def transform(tree: Tree): Tree = {
if (!cpsEnabled) return tree
@@ -46,10 +124,19 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
// this would cause infinite recursion. But we could remove the
// ValDef case here.
- case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
+ case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) =>
debuglog("transforming " + dd.symbol)
atOwner(dd.symbol) {
+ val rhs =
+ if (cpsParamTypes(tpt.tpe).nonEmpty) removeTailReturns(rhs0)
+ else {
+ val checker = new CheckCPSMethodsTraverser
+ checker.traverse(rhs0)
+ if (checker.returnsSeen.nonEmpty && checker.cpsMethodsSeen)
+ unit.error(checker.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method")
+ rhs0
+ }
val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))
debuglog("result "+rhs1)
View
4 test/files/continuations-neg/t5314-missing-result-type.check
@@ -0,0 +1,4 @@
+t5314-missing-result-type.scala:6: error: method bar has return statement; needs result type
+ def bar(x:Int) = return foo(x)
+ ^
+one error found
View
13 test/files/continuations-neg/t5314-missing-result-type.scala
@@ -0,0 +1,13 @@
+import scala.util.continuations._
+
+object Test extends App {
+ def foo(x:Int): Int @cps[Int] = x
+
+ def bar(x:Int) = return foo(x)
+
+ reset {
+ val res = bar(8)
+ println(res)
+ res
+ }
+}
View
4 test/files/continuations-neg/t5314-npe.check
@@ -0,0 +1,4 @@
+t5314-npe.scala:2: error: method bar has return statement; needs result type
+ def bar(x:Int) = { return x; x } // NPE
+ ^
+one error found
View
3  test/files/continuations-neg/t5314-npe.scala
@@ -0,0 +1,3 @@
+object Test extends App {
+ def bar(x:Int) = { return x; x } // NPE
+}
View
4 test/files/continuations-neg/t5314-return-reset.check
@@ -0,0 +1,4 @@
+t5314-return-reset.scala:14: error: return expressions not allowed, since method calls CPS method
+ if (rnd.nextInt(100) > 50) return 5 // not allowed, since method is calling `reset`
+ ^
+one error found
View
21 test/files/continuations-neg/t5314-return-reset.scala
@@ -0,0 +1,21 @@
+import scala.util.continuations._
+import scala.util.Random
+
+object Test extends App {
+ val rnd = new Random
+
+ def foo(x: Int): Int @cps[Int] = shift { k => k(x) }
+
+ def bar(x: Int): Int @cps[Int] = return foo(x)
+
+ def caller(): Int = {
+ val v: Int = reset {
+ val res: Int = bar(8)
+ if (rnd.nextInt(100) > 50) return 5 // not allowed, since method is calling `reset`
+ 42
+ }
+ v
+ }
+
+ caller()
+}
View
6 test/files/continuations-neg/t5314-type-error.check
@@ -0,0 +1,6 @@
+t5314-type-error.scala:7: error: type mismatch;
+ found : Int @util.continuations.package.cps[Int]
+ required: Int @util.continuations.package.cps[String]
+ def bar(x:Int): Int @cps[String] = return foo(x)
+ ^
+one error found
View
17 test/files/continuations-neg/t5314-type-error.scala
@@ -0,0 +1,17 @@
+import scala.util.continuations._
+
+object Test extends App {
+ def foo(x:Int): Int @cps[Int] = shift { k => k(x) }
+
+ // should be a type error
+ def bar(x:Int): Int @cps[String] = return foo(x)
+
+ def caller(): Unit = {
+ val v: String = reset {
+ val res: Int = bar(8)
+ "hello"
+ }
+ }
+
+ caller()
+}
View
5 test/files/continuations-run/t5314-2.check
@@ -0,0 +1,5 @@
+8
+hi
+8
+from try
+8
View
44 test/files/continuations-run/t5314-2.scala
@@ -0,0 +1,44 @@
+import scala.util.continuations._
+
+class ReturnRepro {
+ def s1: Int @cps[Any] = shift { k => k(5) }
+ def caller = reset { println(p(3)) }
+ def caller2 = reset { println(p2(3)) }
+ def caller3 = reset { println(p3(3)) }
+
+ def p(i: Int): Int @cps[Any] = {
+ val v= s1 + 3
+ return v
+ }
+
+ def p2(i: Int): Int @cps[Any] = {
+ val v = s1 + 3
+ if (v > 0) {
+ println("hi")
+ return v
+ } else {
+ println("hi")
+ return 8
+ }
+ }
+
+ def p3(i: Int): Int @cps[Any] = {
+ val v = s1 + 3
+ try {
+ println("from try")
+ return v
+ } catch {
+ case e: Exception =>
+ println("from catch")
+ return 7
+ }
+ }
+
+}
+
+object Test extends App {
+ val repro = new ReturnRepro
+ repro.caller
+ repro.caller2
+ repro.caller3
+}
View
4 test/files/continuations-run/t5314-3.check
@@ -0,0 +1,4 @@
+enter return expr
+8
+hi
+8
View
27 test/files/continuations-run/t5314-3.scala
@@ -0,0 +1,27 @@
+import scala.util.continuations._
+
+class ReturnRepro {
+ def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) }
+ def caller = reset { println(p(3)) }
+ def caller2 = reset { println(p2(3)) }
+
+ def p(i: Int): Int @cpsParam[Unit, Any] = {
+ val v= s1 + 3
+ return { println("enter return expr"); v }
+ }
+
+ def p2(i: Int): Int @cpsParam[Unit, Any] = {
+ val v = s1 + 3
+ if (v > 0) {
+ return { println("hi"); v }
+ } else {
+ return { println("hi"); 8 }
+ }
+ }
+}
+
+object Test extends App {
+ val repro = new ReturnRepro
+ repro.caller
+ repro.caller2
+}
View
4 test/files/continuations-run/t5314.check
@@ -0,0 +1,4 @@
+8
+hi
+8
+8
View
41 test/files/continuations-run/t5314.scala
@@ -0,0 +1,41 @@
+import scala.util.continuations._
+
+class ReturnRepro {
+ def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) }
+ def caller = reset { println(p(3)) }
+ def caller2 = reset { println(p2(3)) }
+
+ def p(i: Int): Int @cpsParam[Unit, Any] = {
+ val v= s1 + 3
+ return v
+ }
+
+ def p2(i: Int): Int @cpsParam[Unit, Any] = {
+ val v = s1 + 3
+ if (v > 0) {
+ println("hi")
+ return v
+ } else {
+ println("hi")
+ return 8
+ }
+ }
+}
+
+object Test extends App {
+ def foo(x:Int): Int @cps[Int] = shift { k => k(x) }
+
+ def bar(x:Int): Int @cps[Int] = return foo(x)
+
+ def nocps(x: Int): Int = { return x; x }
+
+ val repro = new ReturnRepro
+ repro.caller
+ repro.caller2
+
+ reset {
+ val res = bar(8)
+ println(res)
+ res
+ }
+}
Please sign in to comment.
Something went wrong with that request. Please try again.