Skip to content

Commit

Permalink
Merge pull request #748 from dotty-staging/add/non-local/returns
Browse files Browse the repository at this point in the history
Implement non-local returns
  • Loading branch information
odersky committed Aug 9, 2015
2 parents 07e24e8 + 694aabd commit 9eb55f1
Show file tree
Hide file tree
Showing 11 changed files with 250 additions and 23 deletions.
2 changes: 2 additions & 0 deletions src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class Compiler {
new ExtensionMethods,
new ExpandSAMs,
new TailRec,
new LiftTry,
new ClassOf),
List(new PatternMatcher,
new ExplicitOuter,
Expand All @@ -68,6 +69,7 @@ class Compiler {
new LazyVals,
new Memoize,
new LinkScala2ImplClasses,
new NonLocalReturns,
new CapturedVars, // capturedVars has a transformUnit: no phases should introduce local mutable vars here
new Constructors,
new FunctionalInterfaces,
Expand Down
14 changes: 12 additions & 2 deletions src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import config.Printers._
import typer.Mode
import collection.mutable
import typer.ErrorReporting._
import transform.Erasure

import scala.annotation.tailrec

Expand Down Expand Up @@ -161,6 +162,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
def Bind(sym: TermSymbol, body: Tree)(implicit ctx: Context): Bind =
ta.assignType(untpd.Bind(sym.name, body), sym)

/** A pattern corresponding to `sym: tpe` */
def BindTyped(sym: TermSymbol, tpe: Type)(implicit ctx: Context): Bind =
Bind(sym, Typed(Underscore(tpe), TypeTree(tpe)))

def Alternative(trees: List[Tree])(implicit ctx: Context): Alternative =
ta.assignType(untpd.Alternative(trees), trees)

Expand Down Expand Up @@ -733,9 +738,14 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
tree.select(defn.Any_asInstanceOf).appliedToType(tp)
}

/** `tree.asInstanceOf[tp]` unless tree's type already conforms to `tp` */
/** `tree.asInstanceOf[tp]` (or its box/unbox/cast equivalent when after
* erasure and value and non-value types are mixed),
* unless tree's type already conforms to `tp`.
*/
def ensureConforms(tp: Type)(implicit ctx: Context): Tree =
if (tree.tpe <:< tp) tree else asInstance(tp)
if (tree.tpe <:< tp) tree
else if (!ctx.erasedTypes) asInstance(tp)
else Erasure.Boxing.adaptToType(tree, tp)

/** If inititializer tree is `_', the default value of its type,
* otherwise the tree itself.
Expand Down
37 changes: 18 additions & 19 deletions src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -177,26 +177,25 @@ object Contexts {
/** The new implicit references that are introduced by this scope */
private var implicitsCache: ContextualImplicits = null
def implicits: ContextualImplicits = {
if (implicitsCache == null ) {
val outerImplicits =
if (isImportContext && importInfo.hiddenRoot.exists)
outer.implicits exclude importInfo.hiddenRoot
else
outer.implicits
try
implicitsCache = {
val implicitRefs: List[TermRef] =
if (isClassDefContext) owner.thisType.implicitMembers
else if (isImportContext) importInfo.importedImplicits
else if (isNonEmptyScopeContext) scope.implicitDecls
else Nil
if (implicitRefs.isEmpty) outerImplicits
else new ContextualImplicits(implicitRefs, outerImplicits)(this)
}
catch {
case ex: CyclicReference => implicitsCache = outerImplicits
if (implicitsCache == null )
implicitsCache = {
val implicitRefs: List[TermRef] =
if (isClassDefContext)
try owner.thisType.implicitMembers
catch {
case ex: CyclicReference => Nil
}
else if (isImportContext) importInfo.importedImplicits
else if (isNonEmptyScopeContext) scope.implicitDecls
else Nil
val outerImplicits =
if (isImportContext && importInfo.hiddenRoot.exists)
outer.implicits exclude importInfo.hiddenRoot
else
outer.implicits
if (implicitRefs.isEmpty) outerImplicits
else new ContextualImplicits(implicitRefs, outerImplicits)(this)
}
}
implicitsCache
}

Expand Down
1 change: 1 addition & 0 deletions src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ class Definitions {
lazy val Product_productArity = ProductClass.requiredMethod(nme.productArity)
lazy val Product_productPrefix = ProductClass.requiredMethod(nme.productPrefix)
lazy val LanguageModuleClass = ctx.requiredModule("dotty.language").moduleClass.asClass
lazy val NonLocalReturnControlClass = ctx.requiredClass("scala.runtime.NonLocalReturnControl")

// Annotation base classes
lazy val AnnotationClass = ctx.requiredClass("scala.annotation.Annotation")
Expand Down
6 changes: 5 additions & 1 deletion src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,11 @@ object SymDenotations {
/** The expanded name of this denotation. */
final def expandedName(implicit ctx: Context) =
if (is(ExpandedName) || isConstructor) name
else name.expandedName(initial.asSymDenotation.owner)
else {
def legalize(name: Name): Name = // JVM method names may not contain `<' or `>' characters
if (is(Method)) name.replace('<', '(').replace('>', ')') else name
legalize(name.expandedName(initial.asSymDenotation.owner))
}
// need to use initial owner to disambiguate, as multiple private symbols with the same name
// might have been moved from different origins into the same class

Expand Down
66 changes: 66 additions & 0 deletions src/dotty/tools/dotc/transform/LiftTry.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package dotty.tools.dotc
package transform

import TreeTransforms._
import core.DenotTransformers._
import core.Symbols._
import core.Contexts._
import core.Types._
import core.Flags._
import core.Decorators._
import NonLocalReturns._

/** Lifts try's that might be executed on non-empty expression stacks
* to their own methods. I.e.
*
* try body catch handler
*
* is lifted to
*
* { def liftedTree$n() = try body catch handler; liftedTree$n() }
*/
class LiftTry extends MiniPhase with IdentityDenotTransformer { thisTransform =>
import ast.tpd._

/** the following two members override abstract members in Transform */
val phaseName: String = "liftTry"

val treeTransform = new Transform(needLift = false)
val liftingTransform = new Transform(needLift = true)

class Transform(needLift: Boolean) extends TreeTransform {
def phase = thisTransform

override def prepareForApply(tree: Apply)(implicit ctx: Context) =
if (tree.fun.symbol.is(Label)) this
else liftingTransform

override def prepareForValDef(tree: ValDef)(implicit ctx: Context) =
if (!tree.symbol.exists ||
tree.symbol.isSelfSym ||
tree.symbol.owner == ctx.owner.enclosingMethod) this
else liftingTransform

override def prepareForAssign(tree: Assign)(implicit ctx: Context) =
if (tree.lhs.symbol.maybeOwner == ctx.owner.enclosingMethod) this
else liftingTransform

override def prepareForReturn(tree: Return)(implicit ctx: Context) =
if (!isNonLocalReturn(tree)) this
else liftingTransform

override def prepareForTemplate(tree: Template)(implicit ctx: Context) =
treeTransform

override def transformTry(tree: Try)(implicit ctx: Context, info: TransformerInfo): Tree =
if (needLift) {
ctx.debuglog(i"lifting tree at ${tree.pos}, current owner = ${ctx.owner}")
val fn = ctx.newSymbol(
ctx.owner, ctx.freshName("liftedTree").toTermName, Synthetic | Method,
MethodType(Nil, tree.tpe), coord = tree.pos)
tree.changeOwnerAfter(ctx.owner, fn, thisTransform)
Block(DefDef(fn, tree) :: Nil, ref(fn).appliedToNone)
}
else tree
}
}
92 changes: 92 additions & 0 deletions src/dotty/tools/dotc/transform/NonLocalReturns.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package dotty.tools.dotc
package transform

import core._
import Contexts._, Symbols._, Types._, Flags._, Decorators._, StdNames._, Constants._, Phases._
import TreeTransforms._
import ast.Trees._
import collection.mutable

object NonLocalReturns {
import ast.tpd._
def isNonLocalReturn(ret: Return)(implicit ctx: Context) =
ret.from.symbol != ctx.owner.enclosingMethod || ctx.owner.is(Lazy)
}

/** Implement non-local returns using NonLocalReturnControl exceptions.
*/
class NonLocalReturns extends MiniPhaseTransform { thisTransformer =>
override def phaseName = "nonLocalReturns"

import NonLocalReturns._
import ast.tpd._

override def runsAfter: Set[Class[_ <: Phase]] = Set(classOf[ElimByName])

private def ensureConforms(tree: Tree, pt: Type)(implicit ctx: Context) =
if (tree.tpe <:< pt) tree
else Erasure.Boxing.adaptToType(tree, pt)

/** The type of a non-local return expression with given argument type */
private def nonLocalReturnExceptionType(argtype: Type)(implicit ctx: Context) =
defn.NonLocalReturnControlClass.typeRef.appliedTo(argtype)

/** A hashmap from method symbols to non-local return keys */
private val nonLocalReturnKeys = mutable.Map[Symbol, TermSymbol]()

/** Return non-local return key for given method */
private def nonLocalReturnKey(meth: Symbol)(implicit ctx: Context) =
nonLocalReturnKeys.getOrElseUpdate(meth,
ctx.newSymbol(
meth, ctx.freshName("nonLocalReturnKey").toTermName, Synthetic, defn.ObjectType, coord = meth.pos))

/** Generate a non-local return throw with given return expression from given method.
* I.e. for the method's non-local return key, generate:
*
* throw new NonLocalReturnControl(key, expr)
* todo: maybe clone a pre-existing exception instead?
* (but what to do about exceptions that miss their targets?)
*/
private def nonLocalReturnThrow(expr: Tree, meth: Symbol)(implicit ctx: Context) =
Throw(
New(
defn.NonLocalReturnControlClass.typeRef,
ref(nonLocalReturnKey(meth)) :: expr.ensureConforms(defn.ObjectType) :: Nil))

/** Transform (body, key) to:
*
* {
* val key = new Object()
* try {
* body
* } catch {
* case ex: NonLocalReturnControl =>
* if (ex.key().eq(key)) ex.value().asInstanceOf[T]
* else throw ex
* }
* }
*/
private def nonLocalReturnTry(body: Tree, key: TermSymbol, meth: Symbol)(implicit ctx: Context) = {
val keyDef = ValDef(key, New(defn.ObjectType, Nil))
val nonLocalReturnControl = defn.NonLocalReturnControlClass.typeRef
val ex = ctx.newSymbol(meth, nme.ex, EmptyFlags, nonLocalReturnControl, coord = body.pos)
val pat = BindTyped(ex, nonLocalReturnControl)
val rhs = If(
ref(ex).select(nme.key).appliedToNone.select(nme.eq).appliedTo(ref(key)),
ref(ex).select(nme.value).ensureConforms(meth.info.finalResultType),
Throw(ref(ex)))
val catches = CaseDef(pat, EmptyTree, rhs) :: Nil
val tryCatch = Try(body, catches, EmptyTree)
Block(keyDef :: Nil, tryCatch)
}

override def transformDefDef(tree: DefDef)(implicit ctx: Context, info: TransformerInfo): Tree =
nonLocalReturnKeys.remove(tree.symbol) match {
case Some(key) => cpy.DefDef(tree)(rhs = nonLocalReturnTry(tree.rhs, key, tree.symbol))
case _ => tree
}

override def transformReturn(tree: Return)(implicit ctx: Context, info: TransformerInfo): Tree =
if (isNonLocalReturn(tree)) nonLocalReturnThrow(tree.expr, tree.from.symbol).withPos(tree.pos)
else tree
}
2 changes: 1 addition & 1 deletion test/dotc/tests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class tests extends CompilerTest {
@Test def neg_zoo = compileFile(negDir, "zoo", xerrors = 12)

val negTailcallDir = negDir + "tailcall/"
@Test def neg_tailcall_t1672b = compileFile(negTailcallDir, "t1672b", xerrors = 6)
@Test def neg_tailcall_t1672b = compileFile(negTailcallDir, "t1672b", xerrors = 5)
@Test def neg_tailcall_t3275 = compileFile(negTailcallDir, "t3275", xerrors = 1)
@Test def neg_tailcall_t6574 = compileFile(negTailcallDir, "t6574", xerrors = 2)
@Test def neg_tailcall = compileFile(negTailcallDir, "tailrec", xerrors = 7)
Expand Down
File renamed without changes.
21 changes: 21 additions & 0 deletions tests/run/liftedTry.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
object Test {

def raise(x: Int) = { throw new Exception(s"$x"); 0 }
def handle: Throwable => Int = { case ex: Exception => ex.getMessage().toInt }

val x = try raise(1) catch handle

def foo(x: Int) = {
val y = try raise(x) catch handle
y
}

foo(try 3 catch handle)

def main(args: Array[String]) = {
assert(x == 1)
assert(foo(2) == 2)
assert(foo(try raise(3) catch handle) == 3)
}
}

32 changes: 32 additions & 0 deletions tests/run/nonLocalReturns.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
object Test {

def foo(xs: List[Int]): Int = {
xs.foreach(x => return x)
0
}

def bar(xs: List[Int]): Int = {
lazy val y = if (xs.isEmpty) return -1 else xs.head
y
}

def baz(x: Int): Int =
byName { return -2; 3 }

def byName(x: => Int): Int = x

def bam(): Int = { // no non-local return needed here
val foo = {
return -3
3
}
foo
}

def main(args: Array[String]) = {
assert(foo(List(1, 2, 3)) == 1)
assert(bar(Nil) == -1)
assert(baz(3) == -2)
assert(bam() == -3)
}
}

0 comments on commit 9eb55f1

Please sign in to comment.