Permalink
Browse files

Rewrite TailCalls for performance and immutability.

While logging symbols created after typer, I discovered that
TailCalls was far and away the largest creator. It turns out
this was due to a bug where thousands of labels were eagerly
created during tail call analysis, even if the method weren't
tail recursive and no label would ever be required.

This commit shaves 10% off the total number of method symbol
creations (compiling quick.lib drops from 88K to 80K.)
  • Loading branch information...
1 parent 86651c1 commit 1ce4ecd800d2db38d73de103d09060e871520369 @paulp paulp committed Apr 17, 2013
Showing with 111 additions and 116 deletions.
  1. +111 −116 src/compiler/scala/tools/nsc/transform/TailCalls.scala
View
227 src/compiler/scala/tools/nsc/transform/TailCalls.scala
@@ -87,98 +87,112 @@ abstract class TailCalls extends Transform {
* </p>
*/
class TailCallElimination(unit: CompilationUnit) extends Transformer {
- private val defaultReason = "it contains a recursive call not in tail position"
+ private def defaultReason = "it contains a recursive call not in tail position"
+ private val failPositions = perRunCaches.newMap[TailContext, Position]()
+ private val failReasons = perRunCaches.newMap[TailContext, String]()
+ private def tailrecFailure(ctx: TailContext) {
+ val method = ctx.method
+ val failReason = failReasons.getOrElse(ctx, defaultReason)
+ val failPos = failPositions.getOrElse(ctx, ctx.methodPos)
+
+ unit.error(failPos, s"could not optimize @tailrec annotated $method: $failReason")
+ }
/** Has the label been accessed? Then its symbol is in this set. */
- private val accessed = new scala.collection.mutable.HashSet[Symbol]()
+ private val accessed = perRunCaches.newSet[Symbol]()
// `accessed` was stored as boolean in the current context -- this is no longer tenable
// with jumps to labels in tailpositions now considered in tailposition,
// a downstream context may access the label, and the upstream one will be none the wiser
// this is necessary because tail-calls may occur in places where syntactically they seem impossible
// (since we now consider jumps to labels that are in tailposition, such as matchEnd(x) {x})
+ sealed trait TailContext {
+ def method: Symbol // current method
+ def tparams: List[Symbol] // type parameters
+ def methodPos: Position // default position for failure reporting
+ def tailPos: Boolean // context is in tail position
+ def label: Symbol // new label, tail call target
+ def tailLabels: Set[Symbol]
+
+ def enclosingType = method.enclClass.typeOfThis
+ def isEligible = method.isEffectivelyFinal
+ def isMandatory = method.hasAnnotation(TailrecClass)
+ def isTransformed = isEligible && accessed(label)
+
+ def newThis(pos: Position) = {
+ def msg = "Creating new `this` during tailcalls\n method: %s\n current class: %s".format(
+ method.ownerChain.mkString(" -> "),
+ currentClass.ownerChain.mkString(" -> ")
+ )
+ logResult(msg)(method.newValue(nme.THIS, pos, SYNTHETIC) setInfo currentClass.typeOfThis)
+ }
+ override def toString = s"${method.name} tparams=$tparams tailPos=$tailPos label=$label label info=${label.info}"
+ }
- class Context() {
- /** The current method */
- var method: Symbol = NoSymbol
-
- // symbols of label defs in this method that are in tail position
- var tailLabels: Set[Symbol] = Set()
-
- /** The current tail-call label */
- var label: Symbol = NoSymbol
-
- /** The expected type arguments of self-recursive calls */
- var tparams: List[Symbol] = Nil
-
- /** Tells whether we are in a (possible) tail position */
- var tailPos = false
-
- /** The reason this method could not be optimized. */
- var failReason = defaultReason
- var failPos = method.pos
+ object EmptyTailContext extends TailContext {
+ def method = NoSymbol
+ def tparams = Nil
+ def methodPos = NoPosition
+ def tailPos = false
+ def label = NoSymbol
+ def tailLabels = Set.empty[Symbol]
+ }
- def this(that: Context) = {
- this()
- this.method = that.method
- this.tparams = that.tparams
- this.tailPos = that.tailPos
- this.failPos = that.failPos
- this.label = that.label
- this.tailLabels = that.tailLabels
+ class DefDefTailContext(dd: DefDef) extends TailContext {
+ def method = dd.symbol
+ def tparams = dd.tparams map (_.symbol)
+ def methodPos = dd.pos
+ def tailPos = true
+
+ lazy val label = mkLabel()
+ lazy val tailLabels = {
+ // labels are local to a method, so only traverse the rhs of a defdef
+ val collector = new TailPosLabelsTraverser
+ collector traverse dd.rhs
+ collector.tailLabels.toSet
}
- def this(dd: DefDef) {
- this()
- this.method = dd.symbol
- this.tparams = dd.tparams map (_.symbol)
- this.tailPos = true
- this.failPos = dd.pos
-
- /* Create a new method symbol for the current method and store it in
- * the label field.
- */
- this.label = {
- val label = method.newLabel(newTermName("_" + method.name), method.pos)
- val thisParam = method.newSyntheticValueParam(currentClass.typeOfThis)
- label setInfo MethodType(thisParam :: method.tpe.params, method.tpe.finalResultType)
- }
+
+ private def mkLabel() = {
+ val label = method.newLabel(newTermName("_" + method.name), method.pos)
+ val thisParam = method.newSyntheticValueParam(currentClass.typeOfThis)
+ label setInfo MethodType(thisParam :: method.tpe.params, method.tpe.finalResultType)
if (isEligible)
label substInfo (method.tpe.typeParams, tparams)
- }
-
- def enclosingType = method.enclClass.typeOfThis
- def isEligible = method.isEffectivelyFinal
- // @tailrec annotation indicates mandatory transformation
- def isMandatory = method.hasAnnotation(TailrecClass)
- def isTransformed = isEligible && accessed(label)
- def tailrecFailure() = unit.error(failPos, "could not optimize @tailrec annotated " + method + ": " + failReason)
- def newThis(pos: Position) = logResult("Creating new `this` during tailcalls\n method: %s\n current class: %s".format(
- method.ownerChain.mkString(" -> "), currentClass.ownerChain.mkString(" -> "))) {
- method.newValue(nme.THIS, pos, SYNTHETIC) setInfo currentClass.typeOfThis
+ label
}
-
- override def toString(): String = (
- "" + method.name + " tparams: " + tparams + " tailPos: " + tailPos +
- " Label: " + label + " Label type: " + label.info
- )
+ private def isRecursiveCall(t: Tree) = {
+ val receiver = t.symbol
+
+ ( (receiver != null)
+ && receiver.isMethod
+ && (method.name == receiver.name)
+ && (method.enclClass isSubClass receiver.enclClass)
+ )
+ }
+ def containsRecursiveCall(t: Tree) = t exists isRecursiveCall
}
-
- private var ctx: Context = new Context()
- private def noTailContext() = {
- val t = new Context(ctx)
- t.tailPos = false
- t
+ class ClonedTailContext(that: TailContext, override val tailPos: Boolean) extends TailContext {
+ def method = that.method
+ def tparams = that.tparams
+ def methodPos = that.methodPos
+ def tailLabels = that.tailLabels
+ def label = that.label
}
+ private var ctx: TailContext = EmptyTailContext
+ private def noTailContext() = new ClonedTailContext(ctx, tailPos = false)
+ private def yesTailContext() = new ClonedTailContext(ctx, tailPos = true)
+
/** Rewrite this tree to contain no tail recursive calls */
- def transform(tree: Tree, nctx: Context): Tree = {
+ def transform(tree: Tree, nctx: TailContext): Tree = {
val saved = ctx
ctx = nctx
try transform(tree)
finally this.ctx = saved
}
+ def yesTailTransform(tree: Tree): Tree = transform(tree, yesTailContext())
def noTailTransform(tree: Tree): Tree = transform(tree, noTailContext())
def noTailTransforms(trees: List[Tree]) = {
val nctx = noTailContext()
@@ -192,7 +206,6 @@ abstract class TailCalls extends Transform {
case Select(qual, _) => qual
case _ => EmptyTree
}
-
def receiverIsSame = ctx.enclosingType.widen =:= receiver.tpe.widen
def receiverIsSuper = ctx.enclosingType.widen <:< receiver.tpe.widen
def isRecursiveCall = (ctx.method eq fun.symbol) && ctx.tailPos
@@ -204,18 +217,16 @@ abstract class TailCalls extends Transform {
*/
def fail(reason: String) = {
debuglog("Cannot rewrite recursive call at: " + fun.pos + " because: " + reason)
-
- ctx.failReason = reason
+ failReasons(ctx) = reason
treeCopy.Apply(tree, noTailTransform(target), transformArgs)
}
/* Position of failure is that of the tree being considered. */
def failHere(reason: String) = {
- ctx.failPos = fun.pos
+ failPositions(ctx) = fun.pos
fail(reason)
}
def rewriteTailCall(recv: Tree): Tree = {
debuglog("Rewriting tail recursive call: " + fun.pos.lineContent.trim)
-
accessed += ctx.label
typedPos(fun.pos) {
val args = mapWithIndex(transformArgs)((arg, i) => mkAttributedCastHack(arg, ctx.label.info.params(i + 1).tpe))
@@ -241,37 +252,23 @@ abstract class TailCalls extends Transform {
super.transform(tree)
- case dd @ DefDef(_, _, _, vparamss0, _, rhs0) if !dd.symbol.hasAccessorFlag =>
- val newCtx = new Context(dd)
- def isRecursiveCall(t: Tree) = {
- val sym = t.symbol
- (sym != null) && {
- sym.isMethod && (dd.symbol.name == sym.name) && (dd.symbol.enclClass isSubClass sym.enclClass)
- }
- }
- if (newCtx.isMandatory) {
- if (!rhs0.exists(isRecursiveCall)) {
- unit.error(tree.pos, "@tailrec annotated method contains no recursive calls")
- }
- }
-
- // labels are local to a method, so only traverse the rhs of a defdef
- val collectTailPosLabels = new TailPosLabelsTraverser
- collectTailPosLabels traverse rhs0
- newCtx.tailLabels = collectTailPosLabels.tailLabels.toSet
+ case dd @ DefDef(_, name, _, vparamss0, _, rhs0) if !dd.symbol.hasAccessorFlag =>
+ val newCtx = new DefDefTailContext(dd)
+ if (newCtx.isMandatory && !(newCtx containsRecursiveCall rhs0))
+ unit.error(tree.pos, "@tailrec annotated method contains no recursive calls")
- debuglog("Considering " + dd.name + " for tailcalls, with labels in tailpos: "+ newCtx.tailLabels)
+ debuglog(s"Considering $name for tailcalls, with labels in tailpos: ${newCtx.tailLabels}")
val newRHS = transform(rhs0, newCtx)
- deriveDefDef(tree){rhs =>
+ deriveDefDef(tree) { rhs =>
if (newCtx.isTransformed) {
/* We have rewritten the tree, but there may be nested recursive calls remaining.
* If @tailrec is given we need to fail those now.
*/
if (newCtx.isMandatory) {
for (t @ Apply(fn, _) <- newRHS ; if fn.symbol == newCtx.method) {
- newCtx.failPos = t.pos
- newCtx.tailrecFailure()
+ failPositions(newCtx) = t.pos
+ tailrecFailure(newCtx)
}
}
val newThis = newCtx.newThis(tree.pos)
@@ -283,8 +280,8 @@ abstract class TailCalls extends Transform {
))
}
else {
- if (newCtx.isMandatory && newRHS.exists(isRecursiveCall))
- newCtx.tailrecFailure()
+ if (newCtx.isMandatory && (newCtx containsRecursiveCall newRHS))
+ tailrecFailure(newCtx)
newRHS
}
@@ -345,27 +342,25 @@ abstract class TailCalls extends Transform {
case Apply(tapply @ TypeApply(fun, targs), vargs) =>
rewriteApply(tapply, fun, targs, vargs)
- case Apply(fun, args) =>
- if (fun.symbol == Boolean_or || fun.symbol == Boolean_and)
- treeCopy.Apply(tree, fun, transformTrees(args))
- else if (fun.symbol.isLabel && args.nonEmpty && args.tail.isEmpty && ctx.tailLabels(fun.symbol)) {
- // this is to detect tailcalls in translated matches
- // it's a one-argument call to a label that is in a tailposition and that looks like label(x) {x}
- // thus, the argument to the call is in tailposition
- val saved = ctx.tailPos
- ctx.tailPos = true
- debuglog("in tailpos label: "+ args.head)
- val res = transform(args.head)
- ctx.tailPos = saved
- if (res ne args.head) {
- // we tail-called -- TODO: shield from false-positives where we rewrite but don't tail-call
- // must leave the jump to the original tailpos-label (fun)!
- // there might be *a* tailcall *in* res, but it doesn't mean res *always* tailcalls
- treeCopy.Apply(tree, fun, List(res))
- }
- else rewriteApply(fun, fun, Nil, args)
- } else rewriteApply(fun, fun, Nil, args)
+ case Apply(fun, args) if fun.symbol == Boolean_or || fun.symbol == Boolean_and =>
+ treeCopy.Apply(tree, fun, transformTrees(args))
+
+ // this is to detect tailcalls in translated matches
+ // it's a one-argument call to a label that is in a tailposition and that looks like label(x) {x}
+ // thus, the argument to the call is in tailposition
+ case Apply(fun, args @ (arg :: Nil)) if fun.symbol.isLabel && ctx.tailLabels(fun.symbol) =>
+ debuglog(s"in tailpos label: $arg")
+ val res = yesTailTransform(arg)
+ // we tail-called -- TODO: shield from false-positives where we rewrite but don't tail-call
+ // must leave the jump to the original tailpos-label (fun)!
+ // there might be *a* tailcall *in* res, but it doesn't mean res *always* tailcalls
+ if (res ne arg)
+ treeCopy.Apply(tree, fun, res :: Nil)
+ else
+ rewriteApply(fun, fun, Nil, args)
+ case Apply(fun, args) =>
+ rewriteApply(fun, fun, Nil, args)
case Alternative(_) | Star(_) | Bind(_, _) =>
sys.error("We should've never gotten inside a pattern")
case Select(qual, name) =>

0 comments on commit 1ce4ecd

Please sign in to comment.