Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions src/dotty/tools/dotc/transform/TailRec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
final val labelPrefix = "tailLabel"
final val labelFlags = Flags.Synthetic | Flags.Label

/** Symbols of methods that have @tailrec annotatios inside */
private val methodsWithInnerAnnots = new collection.mutable.HashSet[Symbol]()

override def transformUnit(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = {
methodsWithInnerAnnots.clear()
tree
}

override def transformTyped(tree: Typed)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (tree.tpt.tpe.hasAnnotation(defn.TailrecAnnot))
methodsWithInnerAnnots += ctx.owner.enclosingMethod
tree
}

private def mkLabel(method: Symbol, abstractOverClass: Boolean)(implicit c: Context): TermSymbol = {
val name = c.freshName(labelPrefix)

Expand Down Expand Up @@ -137,10 +151,10 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
}
})
}
case d: DefDef if d.symbol.hasAnnotation(defn.TailrecAnnot) =>
case d: DefDef if d.symbol.hasAnnotation(defn.TailrecAnnot) || methodsWithInnerAnnots.contains(d.symbol) =>
ctx.error("TailRec optimisation not applicable, method is neither private nor final so can be overridden", d.pos)
d
case d if d.symbol.hasAnnotation(defn.TailrecAnnot) =>
case d if d.symbol.hasAnnotation(defn.TailrecAnnot) || methodsWithInnerAnnots.contains(d.symbol) =>
ctx.error("TailRec optimisation not applicable, not a method", d.pos)
d
case _ => tree
Expand Down Expand Up @@ -180,7 +194,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete

override def transform(tree: Tree)(implicit c: Context): Tree = {
/* A possibly polymorphic apply to be considered for tail call transformation. */
def rewriteApply(tree: Tree, sym: Symbol): Tree = {
def rewriteApply(tree: Tree, sym: Symbol, required: Boolean = false): Tree = {
def receiverArgumentsAndSymbol(t: Tree, accArgs: List[List[Tree]] = Nil, accT: List[Tree] = Nil):
(Tree, Tree, List[List[Tree]], List[Tree], Symbol) = t match {
case TypeApply(fun, targs) if fun.symbol eq t.symbol => receiverArgumentsAndSymbol(fun, accArgs, targs)
Expand Down Expand Up @@ -216,7 +230,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
}
}
def fail(reason: String) = {
if (isMandatory) c.error(s"Cannot rewrite recursive call: $reason", tree.pos)
if (isMandatory || required) c.error(s"Cannot rewrite recursive call: $reason", tree.pos)
else c.debuglog("Cannot rewrite recursive call at: " + tree.pos + " because: " + reason)
continue
}
Expand Down Expand Up @@ -299,7 +313,8 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
noTailTransforms(stats),
transform(expr)
)

case tree @ Typed(t: Apply, tpt) if tpt.tpe.hasAnnotation(defn.TailrecAnnot) =>
tpd.Typed(rewriteApply(t, t.fun.symbol, required = true), tpt)
case tree@If(cond, thenp, elsep) =>
tpd.cpy.If(tree)(
noTailTransform(cond),
Expand Down
10 changes: 10 additions & 0 deletions tests/neg/tailcall/i1221.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import annotation.tailrec

object I1221{
final def foo(a: Int): Int = {
if ((foo(a - 1): @tailrec) > 0) // error: not in tail position
foo(a - 1): @tailrec
else
foo(a - 2): @tailrec
}
}
10 changes: 10 additions & 0 deletions tests/neg/tailcall/i1221b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import annotation.tailrec

class Test {
def foo(a: Int): Int = { // error: method is not final
if ((foo(a - 1): @tailrec) > 0)
foo(a - 1): @tailrec
else
foo(a - 2): @tailrec
}
}
10 changes: 10 additions & 0 deletions tests/pos/tailcall/i1221.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import annotation.tailrec

object i1221{
final def foo(a: Int): Int = {
if (foo(a - 1) > 0)
foo(a - 1): @tailrec
else
foo(a - 2): @tailrec
}
}