From 4368847a98916c1c675f856b86ab337b16b64a24 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Wed, 29 Dec 2021 16:27:00 +0100 Subject: [PATCH] Visit all trees --- .../dotty/tools/dotc/transform/TailRec.scala | 16 +----- tests/run/tailrec-return.check | 5 ++ tests/run/tailrec-return.scala | 50 +++++++++++++++++++ 3 files changed, 57 insertions(+), 14 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/TailRec.scala b/compiler/src/dotty/tools/dotc/transform/TailRec.scala index eb239f00bcce..ba6f1599def6 100644 --- a/compiler/src/dotty/tools/dotc/transform/TailRec.scala +++ b/compiler/src/dotty/tools/dotc/transform/TailRec.scala @@ -277,23 +277,11 @@ class TailRec extends MiniPhase { def yesTailTransform(tree: Tree)(using Context): Tree = transform(tree, tailPosition = true) - /** If not in tail position a tree traversal may not be needed. - * - * A recursive call may still be in tail position if within the return - * expression of a labeled block. - * A tree traversal may also be needed to report a failure to transform - * a recursive call of a @tailrec annotated method (i.e. `isMandatory`). - */ - private def isTraversalNeeded = - isMandatory || tailPositionLabeledSyms.size > 0 - def noTailTransform(tree: Tree)(using Context): Tree = - if (isTraversalNeeded) transform(tree, tailPosition = false) - else tree + transform(tree, tailPosition = false) def noTailTransforms[Tr <: Tree](trees: List[Tr])(using Context): List[Tr] = - if (isTraversalNeeded) trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]] - else trees + trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]] override def transform(tree: Tree)(using Context): Tree = { /* Rewrite an Apply to be considered for tail call transformation. */ diff --git a/tests/run/tailrec-return.check b/tests/run/tailrec-return.check index 6b81a566d1a7..361e76d8a285 100644 --- a/tests/run/tailrec-return.check +++ b/tests/run/tailrec-return.check @@ -1,2 +1,7 @@ 6 false +true +false +true +Ada Lovelace, Alan Turing +List(9, 10) diff --git a/tests/run/tailrec-return.scala b/tests/run/tailrec-return.scala index 53b5ae73b82d..aa760960403d 100644 --- a/tests/run/tailrec-return.scala +++ b/tests/run/tailrec-return.scala @@ -11,6 +11,56 @@ object Test: if n == 1 then return false true + @annotation.tailrec + def isEvenApply(n: Int): Boolean = + // Return inside an `Apply.fun` + ( + if n != 0 && n != 1 then return isEvenApply(n - 2) + else if n == 1 then return false + else (x: Boolean) => x + )(true) + + @annotation.tailrec + def isEvenWhile(n: Int): Boolean = + // Return inside a `WhileDo.cond` + while( + if n != 0 && n != 1 then return isEvenWhile(n - 2) + else if n == 1 then return false + else true + ) {} + true + + @annotation.tailrec + def isEvenReturn(n: Int): Boolean = + // Return inside a `Return` + return + if n != 0 && n != 1 then return isEvenReturn(n - 2) + else if n == 1 then return false + else true + + @annotation.tailrec + def names(l: List[(String, String) | Null], acc: List[String] = Nil): List[String] = + l match + case Nil => acc.reverse + case x :: xs => + if x == null then return names(xs, acc) + + val displayName = x._1 + " " + x._2 + names(xs, displayName :: acc) + + def nonTail(l: List[Int]): List[Int] = + l match + case Nil => Nil + case x :: xs => + // The call to nonTail should *not* be eliminated + (x + 1) :: nonTail(xs) + + def main(args: Array[String]): Unit = println(sum(3)) println(isEven(5)) + println(isEvenApply(6)) + println(isEvenWhile(7)) + println(isEvenReturn(8)) + println(names(List(("Ada", "Lovelace"), null, ("Alan", "Turing"))).mkString(", ")) + println(nonTail(List(8, 9)))