Skip to content

Commit

Permalink
Merge pull request #2717 from retronym/ticket/6574
Browse files Browse the repository at this point in the history
SI-6574 Support @tailrec for extension methods.
  • Loading branch information
adriaanm committed Jul 10, 2013
2 parents f4ec281 + a90d1f0 commit 5994711
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 4 deletions.
39 changes: 35 additions & 4 deletions src/compiler/scala/tools/nsc/transform/ExtensionMethods.scala
Expand Up @@ -208,6 +208,7 @@ abstract class ExtensionMethods extends Transform with TypingTransformers {
companion.moduleClass.newMethod(extensionName, origMeth.pos, origMeth.flags & ~OVERRIDE & ~PROTECTED | FINAL)
setAnnotations origMeth.annotations
)
origMeth.removeAnnotation(TailrecClass) // it's on the extension method, now.
companion.info.decls.enter(extensionMeth)
}

Expand All @@ -221,15 +222,16 @@ abstract class ExtensionMethods extends Transform with TypingTransformers {
val extensionParams = allParameters(extensionMono)
val extensionThis = gen.mkAttributedStableRef(thiz setPos extensionMeth.pos)

val extensionBody = (
rhs
val extensionBody: Tree = {
val tree = rhs
.substituteSymbols(origTpeParams, extensionTpeParams)
.substituteSymbols(origParams, extensionParams)
.substituteThis(origThis, extensionThis)
.changeOwner(origMeth -> extensionMeth)
)
new SubstututeRecursion(origMeth, extensionMeth, unit).transform(tree)
}

// Record the extension method ( FIXME: because... ? )
// Record the extension method. Later, in `Extender#transformStats`, these will be added to the companion object.
extensionDefs(companion) += atPos(tree.pos)(DefDef(extensionMeth, extensionBody))

// These three lines are assembling Foo.bar$extension[T1, T2, ...]($this)
Expand Down Expand Up @@ -264,4 +266,33 @@ abstract class ExtensionMethods extends Transform with TypingTransformers {
stat
}
}

final class SubstututeRecursion(origMeth: Symbol, extensionMeth: Symbol,
unit: CompilationUnit) extends TypingTransformer(unit) {
override def transform(tree: Tree): Tree = tree match {
// SI-6574 Rewrite recursive calls against the extension method so they can
// be tail call optimized later. The tailcalls phases comes before
// erasure, which performs this translation more generally at all call
// sites.
//
// // Source
// class C[C] { def meth[M](a: A) = { { <expr>: C[C'] }.meth[M'] } }
//
// // Translation
// class C[C] { def meth[M](a: A) = { { <expr>: C[C'] }.meth[M'](a1) } }
// object C { def meth$extension[M, C](this$: C[C], a: A)
// = { meth$extension[M', C']({ <expr>: C[C'] })(a1) } }
case treeInfo.Applied(sel @ Select(qual, _), targs, argss) if sel.symbol == origMeth =>
import gen.CODE._
localTyper.typedPos(tree.pos) {
val allArgss = List(qual) :: argss
val origThis = extensionMeth.owner.companionClass
val baseType = qual.tpe.baseType(origThis)
val allTargs = targs.map(_.tpe) ::: baseType.typeArgs
val fun = gen.mkAttributedTypeApply(THIS(extensionMeth.owner), extensionMeth, allTargs)
allArgss.foldLeft(fun)(Apply(_, _))
}
case _ => super.transform(tree)
}
}
}
7 changes: 7 additions & 0 deletions test/files/neg/t6574.check
@@ -0,0 +1,7 @@
t6574.scala:4: error: could not optimize @tailrec annotated method notTailPos$extension: it contains a recursive call not in tail position
println("tail")
^
t6574.scala:8: error: could not optimize @tailrec annotated method differentTypeArgs$extension: it is called recursively with different type arguments
{(); new Bad[String, Unit](0)}.differentTypeArgs
^
two errors found
10 changes: 10 additions & 0 deletions test/files/neg/t6574.scala
@@ -0,0 +1,10 @@
class Bad[X, Y](val v: Int) extends AnyVal {
@annotation.tailrec final def notTailPos[Z](a: Int)(b: String) {
this.notTailPos[Z](a)(b)
println("tail")
}

@annotation.tailrec final def differentTypeArgs {
{(); new Bad[String, Unit](0)}.differentTypeArgs
}
}
19 changes: 19 additions & 0 deletions test/files/pos/t6574.scala
@@ -0,0 +1,19 @@
class Bad[X, Y](val v: Int) extends AnyVal {
def vv = v
@annotation.tailrec final def foo[Z](a: Int)(b: String) {
this.foo[Z](a)(b)
}

@annotation.tailrec final def differentReceiver {
{(); new Bad[X, Y](0)}.differentReceiver
}

@annotation.tailrec final def dependent[Z](a: Int)(b: String): b.type = {
this.dependent[Z](a)(b)
}
}

class HK[M[_]](val v: Int) extends AnyVal {
def hk[N[_]]: Unit = if (false) hk[M] else ()
}

1 change: 1 addition & 0 deletions test/files/run/t6574b.check
@@ -0,0 +1 @@
List(5, 4, 3, 2, 1)
7 changes: 7 additions & 0 deletions test/files/run/t6574b.scala
@@ -0,0 +1,7 @@
object Test extends App {
implicit class AnyOps(val i: Int) extends AnyVal {
private def parentsOf(x: Int): List[Int] = if (x == 0) Nil else x :: parentsOf(x - 1)
def parents: List[Int] = parentsOf(i)
}
println((5).parents)
}

0 comments on commit 5994711

Please sign in to comment.