Skip to content

Commit

Permalink
SI-6574 Support @tailrec for extension methods.
Browse files Browse the repository at this point in the history
Currently, when the body of an extension method is transplanted
to the companion object, recursive calls point back to the original
instance method. That changes during erasure, but this is too late
for tail call analysis/elimination.

This commit eagerly updates the recursive calls to point to the
extension method in the companion. It also removes the @tailrec
annotation from the original method.
  • Loading branch information
retronym committed Jul 10, 2013
1 parent 07fc7bb commit a90d1f0
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 4 deletions.
39 changes: 35 additions & 4 deletions src/compiler/scala/tools/nsc/transform/ExtensionMethods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ abstract class ExtensionMethods extends Transform with TypingTransformers {
companion.moduleClass.newMethod(extensionName, origMeth.pos, origMeth.flags & ~OVERRIDE & ~PROTECTED | FINAL)
setAnnotations origMeth.annotations
)
origMeth.removeAnnotation(TailrecClass) // it's on the extension method, now.
companion.info.decls.enter(extensionMeth)
}

Expand All @@ -221,15 +222,16 @@ abstract class ExtensionMethods extends Transform with TypingTransformers {
val extensionParams = allParameters(extensionMono)
val extensionThis = gen.mkAttributedStableRef(thiz setPos extensionMeth.pos)

val extensionBody = (
rhs
val extensionBody: Tree = {
val tree = rhs
.substituteSymbols(origTpeParams, extensionTpeParams)
.substituteSymbols(origParams, extensionParams)
.substituteThis(origThis, extensionThis)
.changeOwner(origMeth -> extensionMeth)
)
new SubstututeRecursion(origMeth, extensionMeth, unit).transform(tree)
}

// Record the extension method ( FIXME: because... ? )
// Record the extension method. Later, in `Extender#transformStats`, these will be added to the companion object.
extensionDefs(companion) += atPos(tree.pos)(DefDef(extensionMeth, extensionBody))

// These three lines are assembling Foo.bar$extension[T1, T2, ...]($this)
Expand Down Expand Up @@ -264,4 +266,33 @@ abstract class ExtensionMethods extends Transform with TypingTransformers {
stat
}
}

final class SubstututeRecursion(origMeth: Symbol, extensionMeth: Symbol,
unit: CompilationUnit) extends TypingTransformer(unit) {
override def transform(tree: Tree): Tree = tree match {
// SI-6574 Rewrite recursive calls against the extension method so they can
// be tail call optimized later. The tailcalls phases comes before
// erasure, which performs this translation more generally at all call
// sites.
//
// // Source
// class C[C] { def meth[M](a: A) = { { <expr>: C[C'] }.meth[M'] } }
//
// // Translation
// class C[C] { def meth[M](a: A) = { { <expr>: C[C'] }.meth[M'](a1) } }
// object C { def meth$extension[M, C](this$: C[C], a: A)
// = { meth$extension[M', C']({ <expr>: C[C'] })(a1) } }
case treeInfo.Applied(sel @ Select(qual, _), targs, argss) if sel.symbol == origMeth =>
import gen.CODE._
localTyper.typedPos(tree.pos) {
val allArgss = List(qual) :: argss
val origThis = extensionMeth.owner.companionClass
val baseType = qual.tpe.baseType(origThis)
val allTargs = targs.map(_.tpe) ::: baseType.typeArgs
val fun = gen.mkAttributedTypeApply(THIS(extensionMeth.owner), extensionMeth, allTargs)
allArgss.foldLeft(fun)(Apply(_, _))
}
case _ => super.transform(tree)
}
}
}
7 changes: 7 additions & 0 deletions test/files/neg/t6574.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
t6574.scala:4: error: could not optimize @tailrec annotated method notTailPos$extension: it contains a recursive call not in tail position
println("tail")
^
t6574.scala:8: error: could not optimize @tailrec annotated method differentTypeArgs$extension: it is called recursively with different type arguments
{(); new Bad[String, Unit](0)}.differentTypeArgs
^
two errors found
10 changes: 10 additions & 0 deletions test/files/neg/t6574.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class Bad[X, Y](val v: Int) extends AnyVal {
@annotation.tailrec final def notTailPos[Z](a: Int)(b: String) {
this.notTailPos[Z](a)(b)
println("tail")
}

@annotation.tailrec final def differentTypeArgs {
{(); new Bad[String, Unit](0)}.differentTypeArgs
}
}
19 changes: 19 additions & 0 deletions test/files/pos/t6574.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
class Bad[X, Y](val v: Int) extends AnyVal {
def vv = v
@annotation.tailrec final def foo[Z](a: Int)(b: String) {
this.foo[Z](a)(b)
}

@annotation.tailrec final def differentReceiver {
{(); new Bad[X, Y](0)}.differentReceiver
}

@annotation.tailrec final def dependent[Z](a: Int)(b: String): b.type = {
this.dependent[Z](a)(b)
}
}

class HK[M[_]](val v: Int) extends AnyVal {
def hk[N[_]]: Unit = if (false) hk[M] else ()
}

1 change: 1 addition & 0 deletions test/files/run/t6574b.check
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
List(5, 4, 3, 2, 1)
7 changes: 7 additions & 0 deletions test/files/run/t6574b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
object Test extends App {
implicit class AnyOps(val i: Int) extends AnyVal {
private def parentsOf(x: Int): List[Int] = if (x == 0) Nil else x :: parentsOf(x - 1)
def parents: List[Int] = parentsOf(i)
}
println((5).parents)
}

0 comments on commit a90d1f0

Please sign in to comment.