diff --git a/src/compiler/scala/tools/nsc/transform/ExtensionMethods.scala b/src/compiler/scala/tools/nsc/transform/ExtensionMethods.scala index 672d9d232a7e..56ec49e96245 100644 --- a/src/compiler/scala/tools/nsc/transform/ExtensionMethods.scala +++ b/src/compiler/scala/tools/nsc/transform/ExtensionMethods.scala @@ -208,6 +208,7 @@ abstract class ExtensionMethods extends Transform with TypingTransformers { companion.moduleClass.newMethod(extensionName, origMeth.pos, origMeth.flags & ~OVERRIDE & ~PROTECTED | FINAL) setAnnotations origMeth.annotations ) + origMeth.removeAnnotation(TailrecClass) // it's on the extension method, now. companion.info.decls.enter(extensionMeth) } @@ -221,15 +222,16 @@ abstract class ExtensionMethods extends Transform with TypingTransformers { val extensionParams = allParameters(extensionMono) val extensionThis = gen.mkAttributedStableRef(thiz setPos extensionMeth.pos) - val extensionBody = ( - rhs + val extensionBody: Tree = { + val tree = rhs .substituteSymbols(origTpeParams, extensionTpeParams) .substituteSymbols(origParams, extensionParams) .substituteThis(origThis, extensionThis) .changeOwner(origMeth -> extensionMeth) - ) + new SubstututeRecursion(origMeth, extensionMeth, unit).transform(tree) + } - // Record the extension method ( FIXME: because... ? ) + // Record the extension method. Later, in `Extender#transformStats`, these will be added to the companion object. extensionDefs(companion) += atPos(tree.pos)(DefDef(extensionMeth, extensionBody)) // These three lines are assembling Foo.bar$extension[T1, T2, ...]($this) @@ -264,4 +266,33 @@ abstract class ExtensionMethods extends Transform with TypingTransformers { stat } } + + final class SubstututeRecursion(origMeth: Symbol, extensionMeth: Symbol, + unit: CompilationUnit) extends TypingTransformer(unit) { + override def transform(tree: Tree): Tree = tree match { + // SI-6574 Rewrite recursive calls against the extension method so they can + // be tail call optimized later. The tailcalls phases comes before + // erasure, which performs this translation more generally at all call + // sites. + // + // // Source + // class C[C] { def meth[M](a: A) = { { : C[C'] }.meth[M'] } } + // + // // Translation + // class C[C] { def meth[M](a: A) = { { : C[C'] }.meth[M'](a1) } } + // object C { def meth$extension[M, C](this$: C[C], a: A) + // = { meth$extension[M', C']({ : C[C'] })(a1) } } + case treeInfo.Applied(sel @ Select(qual, _), targs, argss) if sel.symbol == origMeth => + import gen.CODE._ + localTyper.typedPos(tree.pos) { + val allArgss = List(qual) :: argss + val origThis = extensionMeth.owner.companionClass + val baseType = qual.tpe.baseType(origThis) + val allTargs = targs.map(_.tpe) ::: baseType.typeArgs + val fun = gen.mkAttributedTypeApply(THIS(extensionMeth.owner), extensionMeth, allTargs) + allArgss.foldLeft(fun)(Apply(_, _)) + } + case _ => super.transform(tree) + } + } } diff --git a/test/files/neg/t6574.check b/test/files/neg/t6574.check new file mode 100644 index 000000000000..c67b4ed80403 --- /dev/null +++ b/test/files/neg/t6574.check @@ -0,0 +1,7 @@ +t6574.scala:4: error: could not optimize @tailrec annotated method notTailPos$extension: it contains a recursive call not in tail position + println("tail") + ^ +t6574.scala:8: error: could not optimize @tailrec annotated method differentTypeArgs$extension: it is called recursively with different type arguments + {(); new Bad[String, Unit](0)}.differentTypeArgs + ^ +two errors found diff --git a/test/files/neg/t6574.scala b/test/files/neg/t6574.scala new file mode 100644 index 000000000000..bba97ad62e3f --- /dev/null +++ b/test/files/neg/t6574.scala @@ -0,0 +1,10 @@ +class Bad[X, Y](val v: Int) extends AnyVal { + @annotation.tailrec final def notTailPos[Z](a: Int)(b: String) { + this.notTailPos[Z](a)(b) + println("tail") + } + + @annotation.tailrec final def differentTypeArgs { + {(); new Bad[String, Unit](0)}.differentTypeArgs + } +} diff --git a/test/files/pos/t6574.scala b/test/files/pos/t6574.scala new file mode 100644 index 000000000000..59c1701eb4b5 --- /dev/null +++ b/test/files/pos/t6574.scala @@ -0,0 +1,19 @@ +class Bad[X, Y](val v: Int) extends AnyVal { + def vv = v + @annotation.tailrec final def foo[Z](a: Int)(b: String) { + this.foo[Z](a)(b) + } + + @annotation.tailrec final def differentReceiver { + {(); new Bad[X, Y](0)}.differentReceiver + } + + @annotation.tailrec final def dependent[Z](a: Int)(b: String): b.type = { + this.dependent[Z](a)(b) + } +} + +class HK[M[_]](val v: Int) extends AnyVal { + def hk[N[_]]: Unit = if (false) hk[M] else () +} + diff --git a/test/files/run/t6574b.check b/test/files/run/t6574b.check new file mode 100644 index 000000000000..e10fa4f810ad --- /dev/null +++ b/test/files/run/t6574b.check @@ -0,0 +1 @@ +List(5, 4, 3, 2, 1) diff --git a/test/files/run/t6574b.scala b/test/files/run/t6574b.scala new file mode 100644 index 000000000000..df329a31cac9 --- /dev/null +++ b/test/files/run/t6574b.scala @@ -0,0 +1,7 @@ +object Test extends App { + implicit class AnyOps(val i: Int) extends AnyVal { + private def parentsOf(x: Int): List[Int] = if (x == 0) Nil else x :: parentsOf(x - 1) + def parents: List[Int] = parentsOf(i) + } + println((5).parents) +}