diff --git a/src/compiler/scala/tools/nsc/transform/async/AsyncPhase.scala b/src/compiler/scala/tools/nsc/transform/async/AsyncPhase.scala index 670b59494d2b..d1269685b9a5 100644 --- a/src/compiler/scala/tools/nsc/transform/async/AsyncPhase.scala +++ b/src/compiler/scala/tools/nsc/transform/async/AsyncPhase.scala @@ -31,6 +31,8 @@ abstract class AsyncPhase extends Transform with TypingTransformers with AnfTran stateDiagram: ((Symbol, Tree) => Option[String => Unit]), allowExceptionsToPropagate: Boolean) extends PlainAttachment + def hasAsyncAttachment(dd: DefDef) = dd.hasAttachment[AsyncAttachment] + // Optimization: avoid the transform altogether if there are no async blocks in a unit. private val sourceFilesToTransform = perRunCaches.newSet[SourceFile]() private val awaits: mutable.Set[Symbol] = perRunCaches.newSet[Symbol]() diff --git a/src/compiler/scala/tools/nsc/transform/patmat/MatchOptimization.scala b/src/compiler/scala/tools/nsc/transform/patmat/MatchOptimization.scala index 6c320fd01e1d..9f51e4cd22dc 100644 --- a/src/compiler/scala/tools/nsc/transform/patmat/MatchOptimization.scala +++ b/src/compiler/scala/tools/nsc/transform/patmat/MatchOptimization.scala @@ -209,6 +209,8 @@ trait MatchOptimization extends MatchTreeMaking with MatchApproximation { trait SwitchEmission extends TreeMakers with MatchMonadInterface { import treeInfo.isGuardedCase + def inAsync: Boolean + abstract class SwitchMaker { abstract class SwitchableTreeMakerExtractor { def unapply(x: TreeMaker): Option[Tree] } val SwitchableTreeMaker: SwitchableTreeMakerExtractor @@ -497,7 +499,7 @@ trait MatchOptimization extends MatchTreeMaking with MatchApproximation { class RegularSwitchMaker(scrutSym: Symbol, matchFailGenOverride: Option[Tree => Tree], val unchecked: Boolean) extends SwitchMaker { import CODE._ val switchableTpe = Set(ByteTpe, ShortTpe, IntTpe, CharTpe, StringTpe) val alternativesSupported = true - val canJump = true + val canJump = !inAsync // Constant folding sets the type of a constant tree to `ConstantType(Constant(folded))` // The tree itself can be a literal, an ident, a selection, ... diff --git a/src/compiler/scala/tools/nsc/transform/patmat/PatternMatching.scala b/src/compiler/scala/tools/nsc/transform/patmat/PatternMatching.scala index a0b53cb0de90..e6a9d9d95418 100644 --- a/src/compiler/scala/tools/nsc/transform/patmat/PatternMatching.scala +++ b/src/compiler/scala/tools/nsc/transform/patmat/PatternMatching.scala @@ -65,7 +65,17 @@ trait PatternMatching extends Transform def newTransformer(unit: CompilationUnit): AstTransformer = new MatchTransformer(unit) class MatchTransformer(unit: CompilationUnit) extends TypingTransformer(unit) { + private var inAsync = false + override def transform(tree: Tree): Tree = tree match { + case dd: DefDef if async.hasAsyncAttachment(dd) => + val wasInAsync = inAsync + try { + inAsync = true + super.transform(dd) + } finally + inAsync = wasInAsync + case CaseDef(UnApply(Apply(Select(qual, nme.unapply), Ident(nme.SELECTOR_DUMMY) :: Nil), (bind@Bind(b, Ident(nme.WILDCARD))) :: Nil), guard, body) if guard.isEmpty && qual.symbol == definitions.NonFatalModule => transform(treeCopy.CaseDef( @@ -103,16 +113,17 @@ trait PatternMatching extends Transform } def translator(selectorPos: Position): MatchTranslator with CodegenCore = { - new OptimizingMatchTranslator(localTyper, selectorPos) + new OptimizingMatchTranslator(localTyper, selectorPos, inAsync) } } - class OptimizingMatchTranslator(val typer: analyzer.Typer, val selectorPos: Position) extends MatchTranslator - with MatchOptimizer - with MatchAnalyzer - with Solver + class OptimizingMatchTranslator(val typer: analyzer.Typer, val selectorPos: Position, val inAsync: Boolean) + extends MatchTranslator + with MatchOptimizer + with MatchAnalyzer + with Solver } trait Debugging { diff --git a/test/async/run/switch-await-in-guard.scala b/test/async/run/switch-await-in-guard.scala new file mode 100644 index 000000000000..28678792054f --- /dev/null +++ b/test/async/run/switch-await-in-guard.scala @@ -0,0 +1,50 @@ +//> using options -Xasync + +import scala.tools.partest.async.OptionAwait._ +import org.junit.Assert._ + +object Test { + def main(args: Array[String]): Unit = { + assertEquals(Some(22), sw1(11)) + assertEquals(Some(3), sw1(3)) + + assertEquals(Some(22), sw2(11)) + assertEquals(Some(3), sw2(3)) + + assertEquals(Some(22), sw3(11)) + assertEquals(Some(44), sw3(22)) + assertEquals(Some(3), sw3(3)) + + assertEquals(Some("22"), swS("11")) + assertEquals(Some("3"), swS("3")) + } + + private def sw1(i: Int) = optionally { + i match { + case 11 if value(Some(430)) > 42 => 22 + case p => p + } + } + + private def sw2(i: Int) = optionally { + i match { + case 11 => if (value(Some(430)) > 42) 22 else i + case p => p + } + } + + private def sw3(i: Int) = optionally { + i match { + case 11 => if (value(Some(430)) > 42) 22 else i + case 22 | 33 => 44 + case p => p + } + } + + private def swS(s: String) = optionally { + s match { + case "11" if value(Some(430)) > 42 => "22" + case p => p + } + } +} diff --git a/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala b/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala index ee152eefec13..d30901331fb0 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala @@ -1003,4 +1003,46 @@ class BytecodeTest extends BytecodeTesting { val lines = compileMethod(c1).instructions.collect { case l: LineNumber => l } assertSameCode(List(LineNumber(2, Label(0))), lines) } + + + @Test + def t12990(): Unit = { + val komp = BytecodeTesting.newCompiler(extraArgs = "-Xasync") + val code = + """import scala.tools.nsc.OptionAwait._ + | + |class C { + | def sw1(i: Int) = optionally { + | i match { + | case 11 if value(Some(430)) > 42 => 22 + | case p => p + | } + | } + | def sw2(i: Int) = optionally { + | i match { + | case 11 => if (value(Some(430)) > 42) 22 else i + | case p => p + | } + | } + | def sw3(i: Int) = optionally { + | i match { + | case 11 => if (value(Some(430)) > 42) 22 else i + | case 22 | 33 => 44 + | case p => p + | } + | } + |} + |""".stripMargin + val cs = komp.compileClasses(code) + + val sm1 = getMethod(cs.find(_.name == "C$stateMachine$async$1").get, "apply") + assertSame(1, sm1.instructions.count(_.opcode == TABLESWITCH)) + + val sm2 = getMethod(cs.find(_.name == "C$stateMachine$async$2").get, "apply") + assertSame(2, sm2.instructions.count(_.opcode == TABLESWITCH)) + + val sm3 = getMethod(cs.find(_.name == "C$stateMachine$async$3").get, "apply") + assertSame(1, sm3.instructions.count(_.opcode == TABLESWITCH)) + assertSame(1, sm3.instructions.count(_.opcode == LOOKUPSWITCH)) + } }