diff --git a/compiler/src/dotty/tools/dotc/transform/TailRec.scala b/compiler/src/dotty/tools/dotc/transform/TailRec.scala index 9330016f3292..ba6f1599def6 100644 --- a/compiler/src/dotty/tools/dotc/transform/TailRec.scala +++ b/compiler/src/dotty/tools/dotc/transform/TailRec.scala @@ -277,23 +277,11 @@ class TailRec extends MiniPhase { def yesTailTransform(tree: Tree)(using Context): Tree = transform(tree, tailPosition = true) - /** If not in tail position a tree traversal may not be needed. - * - * A recursive call may still be in tail position if within the return - * expression of a labeled block. - * A tree traversal may also be needed to report a failure to transform - * a recursive call of a @tailrec annotated method (i.e. `isMandatory`). - */ - private def isTraversalNeeded = - isMandatory || tailPositionLabeledSyms.size > 0 - def noTailTransform(tree: Tree)(using Context): Tree = - if (isTraversalNeeded) transform(tree, tailPosition = false) - else tree + transform(tree, tailPosition = false) def noTailTransforms[Tr <: Tree](trees: List[Tr])(using Context): List[Tr] = - if (isTraversalNeeded) trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]] - else trees + trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]] override def transform(tree: Tree)(using Context): Tree = { /* Rewrite an Apply to be considered for tail call transformation. */ @@ -444,7 +432,7 @@ class TailRec extends MiniPhase { case Return(expr, from) => val fromSym = from.symbol - val inTailPosition = fromSym.is(Label) && tailPositionLabeledSyms.contains(fromSym) + val inTailPosition = !fromSym.is(Label) || tailPositionLabeledSyms.contains(fromSym) cpy.Return(tree)(transform(expr, inTailPosition), from) case _ => diff --git a/tests/run/tailrec-return.check b/tests/run/tailrec-return.check new file mode 100644 index 000000000000..361e76d8a285 --- /dev/null +++ b/tests/run/tailrec-return.check @@ -0,0 +1,7 @@ +6 +false +true +false +true +Ada Lovelace, Alan Turing +List(9, 10) diff --git a/tests/run/tailrec-return.scala b/tests/run/tailrec-return.scala new file mode 100644 index 000000000000..aa760960403d --- /dev/null +++ b/tests/run/tailrec-return.scala @@ -0,0 +1,66 @@ +object Test: + + @annotation.tailrec + def sum(n: Int, acc: Int = 0): Int = + if n != 0 then return sum(n - 1, acc + n) + acc + + @annotation.tailrec + def isEven(n: Int): Boolean = + if n != 0 && n != 1 then return isEven(n - 2) + if n == 1 then return false + true + + @annotation.tailrec + def isEvenApply(n: Int): Boolean = + // Return inside an `Apply.fun` + ( + if n != 0 && n != 1 then return isEvenApply(n - 2) + else if n == 1 then return false + else (x: Boolean) => x + )(true) + + @annotation.tailrec + def isEvenWhile(n: Int): Boolean = + // Return inside a `WhileDo.cond` + while( + if n != 0 && n != 1 then return isEvenWhile(n - 2) + else if n == 1 then return false + else true + ) {} + true + + @annotation.tailrec + def isEvenReturn(n: Int): Boolean = + // Return inside a `Return` + return + if n != 0 && n != 1 then return isEvenReturn(n - 2) + else if n == 1 then return false + else true + + @annotation.tailrec + def names(l: List[(String, String) | Null], acc: List[String] = Nil): List[String] = + l match + case Nil => acc.reverse + case x :: xs => + if x == null then return names(xs, acc) + + val displayName = x._1 + " " + x._2 + names(xs, displayName :: acc) + + def nonTail(l: List[Int]): List[Int] = + l match + case Nil => Nil + case x :: xs => + // The call to nonTail should *not* be eliminated + (x + 1) :: nonTail(xs) + + + def main(args: Array[String]): Unit = + println(sum(3)) + println(isEven(5)) + println(isEvenApply(6)) + println(isEvenWhile(7)) + println(isEvenReturn(8)) + println(names(List(("Ada", "Lovelace"), null, ("Alan", "Turing"))).mkString(", ")) + println(nonTail(List(8, 9)))