diff --git a/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala b/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala index 164e602bbe1b..96ae4e8a00af 100644 --- a/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala +++ b/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala @@ -3,6 +3,7 @@ package backend package jvm import scala.annotation.switch +import scala.collection.mutable.SortedMap import scala.tools.asm import scala.tools.asm.{Handle, Label, Opcodes} @@ -840,61 +841,170 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { generatedType } - /* - * A Match node contains one or more case clauses, - * each case clause lists one or more Int values to use as keys, and a code block. - * Except the "default" case clause which (if it exists) doesn't list any Int key. - * - * On a first pass over the case clauses, we flatten the keys and their targets (the latter represented with asm.Labels). - * That representation allows JCodeMethodV to emit a lookupswitch or a tableswitch. - * - * On a second pass, we emit the switch blocks, one for each different target. + /* A Match node contains one or more case clauses, each case clause lists one or more + * Int/String values to use as keys, and a code block. The exception is the "default" case + * clause which doesn't list any key (there is exactly one of these per match). */ private def genMatch(tree: Match): BType = tree match { case Match(selector, cases) => lineNumber(tree) - genLoad(selector, INT) val generatedType = tpeTK(tree) + val postMatch = new asm.Label - var flatKeys: List[Int] = Nil - var targets: List[asm.Label] = Nil - var default: asm.Label = null - var switchBlocks: List[(asm.Label, Tree)] = Nil - - // collect switch blocks and their keys, but don't emit yet any switch-block. - for (caze @ CaseDef(pat, guard, body) <- cases) { - assert(guard == tpd.EmptyTree, guard) - val switchBlockPoint = new asm.Label - switchBlocks ::= (switchBlockPoint, body) - pat match { - case Literal(value) => - flatKeys ::= value.intValue - targets ::= switchBlockPoint - case Ident(nme.WILDCARD) => - assert(default == null, s"multiple default targets in a Match node, at ${tree.span}") - default = switchBlockPoint - case Alternative(alts) => - alts foreach { - case Literal(value) => - flatKeys ::= value.intValue - targets ::= switchBlockPoint - case _ => - abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}") - } - case _ => - abort(s"Invalid pattern in Match node: $tree at: ${tree.span}") + // Only two possible selector types exist in `Match` trees at this point: Int and String + if (tpeTK(selector) == INT) { + + /* On a first pass over the case clauses, we flatten the keys and their + * targets (the latter represented with asm.Labels). That representation + * allows JCodeMethodV to emit a lookupswitch or a tableswitch. + * + * On a second pass, we emit the switch blocks, one for each different target. + */ + + var flatKeys: List[Int] = Nil + var targets: List[asm.Label] = Nil + var default: asm.Label = null + var switchBlocks: List[(asm.Label, Tree)] = Nil + + genLoad(selector, INT) + + // collect switch blocks and their keys, but don't emit yet any switch-block. + for (caze @ CaseDef(pat, guard, body) <- cases) { + assert(guard == tpd.EmptyTree, guard) + val switchBlockPoint = new asm.Label + switchBlocks ::= (switchBlockPoint, body) + pat match { + case Literal(value) => + flatKeys ::= value.intValue + targets ::= switchBlockPoint + case Ident(nme.WILDCARD) => + assert(default == null, s"multiple default targets in a Match node, at ${tree.span}") + default = switchBlockPoint + case Alternative(alts) => + alts foreach { + case Literal(value) => + flatKeys ::= value.intValue + targets ::= switchBlockPoint + case _ => + abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}") + } + case _ => + abort(s"Invalid pattern in Match node: $tree at: ${tree.span}") + } } - } - bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY) + bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY) - // emit switch-blocks. - val postMatch = new asm.Label - for (sb <- switchBlocks.reverse) { - val (caseLabel, caseBody) = sb - markProgramPoint(caseLabel) - genLoad(caseBody, generatedType) - bc goTo postMatch + // emit switch-blocks. + for (sb <- switchBlocks.reverse) { + val (caseLabel, caseBody) = sb + markProgramPoint(caseLabel) + genLoad(caseBody, generatedType) + bc goTo postMatch + } + } else { + + /* Since the JVM doesn't have a way to switch on a string, we switch + * on the `hashCode` of the string then do an `equals` check (with a + * possible second set of jumps if blocks can be reach from multiple + * string alternatives). + * + * This mirrors the way that Java compiles `switch` on Strings. + */ + + var default: asm.Label = null + var indirectBlocks: List[(asm.Label, Tree)] = Nil + + import scala.collection.mutable + + // Cases grouped by their hashCode + val casesByHash = SortedMap.empty[Int, List[(String, Either[asm.Label, Tree])]] + var caseFallback: Tree = null + + for (caze @ CaseDef(pat, guard, body) <- cases) { + assert(guard == tpd.EmptyTree, guard) + pat match { + case Literal(value) => + val strValue = value.stringValue + casesByHash.updateWith(strValue.##) { existingCasesOpt => + val newCase = (strValue, Right(body)) + Some(newCase :: existingCasesOpt.getOrElse(Nil)) + } + case Ident(nme.WILDCARD) => + assert(default == null, s"multiple default targets in a Match node, at ${tree.span}") + default = new asm.Label + indirectBlocks ::= (default, body) + case Alternative(alts) => + // We need an extra basic block since multiple strings can lead to this code + val indirectCaseGroupLabel = new asm.Label + indirectBlocks ::= (indirectCaseGroupLabel, body) + alts foreach { + case Literal(value) => + val strValue = value.stringValue + casesByHash.updateWith(strValue.##) { existingCasesOpt => + val newCase = (strValue, Left(indirectCaseGroupLabel)) + Some(newCase :: existingCasesOpt.getOrElse(Nil)) + } + case _ => + abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}") + } + + case _ => + abort(s"Invalid pattern in Match node: $tree at: ${tree.span}") + } + } + + // Organize the hashCode options into switch cases + var flatKeys: List[Int] = Nil + var targets: List[asm.Label] = Nil + var hashBlocks: List[(asm.Label, List[(String, Either[asm.Label, Tree])])] = Nil + for ((hashValue, hashCases) <- casesByHash) { + val switchBlockPoint = new asm.Label + hashBlocks ::= (switchBlockPoint, hashCases) + flatKeys ::= hashValue + targets ::= switchBlockPoint + } + + // Push the hashCode of the string (or `0` it is `null`) onto the stack and switch on it + genLoadIf( + If( + tree.selector.select(defn.Any_==).appliedTo(nullLiteral), + Literal(Constant(0)), + tree.selector.select(defn.Any_hashCode).appliedToNone + ), + INT + ) + bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY) + + // emit blocks for each hash case + for ((hashLabel, caseAlternatives) <- hashBlocks.reverse) { + markProgramPoint(hashLabel) + for ((caseString, indirectLblOrBody) <- caseAlternatives) { + val comparison = if (caseString == null) defn.Any_== else defn.Any_equals + val condp = Literal(Constant(caseString)).select(defn.Any_==).appliedTo(tree.selector) + val keepGoing = new asm.Label + indirectLblOrBody match { + case Left(jump) => + genCond(condp, jump, keepGoing, targetIfNoJump = keepGoing) + + case Right(caseBody) => + val thisCaseMatches = new asm.Label + genCond(condp, thisCaseMatches, keepGoing, targetIfNoJump = thisCaseMatches) + markProgramPoint(thisCaseMatches) + genLoad(caseBody, generatedType) + bc goTo postMatch + } + markProgramPoint(keepGoing) + } + bc goTo default + } + + // emit blocks for common patterns + for ((caseLabel, caseBody) <- indirectBlocks.reverse) { + markProgramPoint(caseLabel) + genLoad(caseBody, generatedType) + bc goTo postMatch + } } markProgramPoint(postMatch) diff --git a/compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala b/compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala index 183b25213b4d..d2cc27adb29f 100644 --- a/compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala +++ b/compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala @@ -2872,12 +2872,6 @@ class JSCodeGen()(using genCtx: Context) { def abortMatch(msg: String): Nothing = throw new FatalError(s"$msg in switch-like pattern match at ${tree.span}: $tree") - /* Although GenBCode adapts the scrutinee and the cases to `int`, only - * true `int`s can reach the back-end, as asserted by the String-switch - * transformation in `cleanup`. Therefore, we do not adapt, preserving - * the `string`s and `null`s that come out of the pattern matching in - * Scala 2.13.2+. - */ val genSelector = genExpr(selector) // Sanity check: we can handle Ints and Strings (including `null`s), but nothing else @@ -2934,11 +2928,6 @@ class JSCodeGen()(using genCtx: Context) { * When no optimization applies, and any of the case values is not a * literal int, we emit a series of `if..else` instead of a `js.Match`. * This became necessary in 2.13.2 with strings and nulls. - * - * Note that dotc has not adopted String-switch-Matches yet, so these code - * paths are dead code at the moment. However, they already existed in the - * scalac, so were ported, to be immediately available and working when - * dotc starts emitting switch-Matches on Strings. */ def isInt(tree: js.Tree): Boolean = tree.tpe == jstpe.IntType diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index ea72f476eb18..536c566bfead 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -20,7 +20,7 @@ import util.Property._ /** The pattern matching transform. * After this phase, the only Match nodes remaining in the code are simple switches - * where every pattern is an integer constant + * where every pattern is an integer or string constant */ class PatternMatcher extends MiniPhase { import ast.tpd._ @@ -768,13 +768,15 @@ object PatternMatcher { (tpe isRef defn.IntClass) || (tpe isRef defn.ByteClass) || (tpe isRef defn.ShortClass) || - (tpe isRef defn.CharClass) + (tpe isRef defn.CharClass) || + (tpe isRef defn.StringClass) - val seen = mutable.Set[Int]() + val seen = mutable.Set[Any]() - def isNewIntConst(tree: Tree) = tree match { - case Literal(const) if const.isIntRange && !seen.contains(const.intValue) => - seen += const.intValue + def isNewSwitchableConst(tree: Tree) = tree match { + case Literal(const) + if (const.isIntRange || const.tag == Constants.StringTag) && !seen.contains(const.value) => + seen += const.value true case _ => false @@ -789,7 +791,7 @@ object PatternMatcher { val alts = List.newBuilder[Tree] def rec(innerPlan: Plan): Boolean = innerPlan match { case SeqPlan(TestPlan(EqualTest(tree), scrut, _, ReturnPlan(`innerLabel`)), tail) - if scrut === scrutinee && isNewIntConst(tree) => + if scrut === scrutinee && isNewSwitchableConst(tree) => alts += tree rec(tail) case ReturnPlan(`outerLabel`) => @@ -809,7 +811,7 @@ object PatternMatcher { def recur(plan: Plan): List[(List[Tree], Plan)] = plan match { case SeqPlan(testPlan @ TestPlan(EqualTest(tree), scrut, _, ons), tail) - if scrut === scrutinee && !canFallThrough(ons) && isNewIntConst(tree) => + if scrut === scrutinee && !canFallThrough(ons) && isNewSwitchableConst(tree) => (tree :: Nil, ons) :: recur(tail) case SeqPlan(AlternativesPlan(alts, ons), tail) => (alts, ons) :: recur(tail) @@ -832,29 +834,32 @@ object PatternMatcher { /** Emit a switch-match */ private def emitSwitchMatch(scrutinee: Tree, cases: List[(List[Tree], Plan)]): Match = { - /* Make sure to adapt the scrutinee to Int, as well as all the alternatives - * of all cases, so that only Matches on pritimive Ints survive this phase. + /* Make sure to adapt the scrutinee to Int or String, as well as all the + * alternatives, so that only Matches on pritimive Ints or Strings survive + * this phase. */ - val intScrutinee = - if (scrutinee.tpe.widen.isRef(defn.IntClass)) scrutinee - else scrutinee.select(nme.toInt) + val (primScrutinee, scrutineeTpe) = + if (scrutinee.tpe.widen.isRef(defn.IntClass)) (scrutinee, defn.IntType) + else if (scrutinee.tpe.widen.isRef(defn.StringClass)) (scrutinee, defn.StringType) + else (scrutinee.select(nme.toInt), defn.IntType) - def intLiteral(lit: Tree): Tree = + def primLiteral(lit: Tree): Tree = val Literal(constant) = lit if (constant.tag == Constants.IntTag) lit + else if (constant.tag == Constants.StringTag) lit else cpy.Literal(lit)(Constant(constant.intValue)) val caseDefs = cases.map { (alts, ons) => val pat = alts match { - case alt :: Nil => intLiteral(alt) - case Nil => Underscore(defn.IntType) // default case - case _ => Alternative(alts.map(intLiteral)) + case alt :: Nil => primLiteral(alt) + case Nil => Underscore(scrutineeTpe) // default case + case _ => Alternative(alts.map(primLiteral)) } CaseDef(pat, EmptyTree, emit(ons)) } - Match(intScrutinee, caseDefs) + Match(primScrutinee, caseDefs) } /** If selfCheck is `true`, used to check whether a tree gets generated twice */ diff --git a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala index 22fcacbb6a9c..2b1b5b57b1fa 100644 --- a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala @@ -118,6 +118,28 @@ class TestBCode extends DottyBytecodeTest { } } + @Test def switchOnStrings = { + val source = + """ + |object Foo { + | import scala.annotation.switch + | def foo(s: String) = s match { + | case "AaAa" => println(3) + | case "BBBB" | "c" => println(2) + | case "D" | "E" => println(1) + | case _ => println(0) + | } + |} + """.stripMargin + + checkBCode(source) { dir => + val moduleIn = dir.lookupName("Foo$.class", directory = false) + val moduleNode = loadClassNode(moduleIn.input) + val methodNode = getMethod(moduleNode, "foo") + assert(verifySwitch(methodNode)) + } + } + @Test def matchWithDefaultNoThrowMatchError = { val source = """class Test { diff --git a/tests/run/string-switch-defaults-null.check b/tests/run/string-switch-defaults-null.check new file mode 100644 index 000000000000..4bbcfcf56827 --- /dev/null +++ b/tests/run/string-switch-defaults-null.check @@ -0,0 +1,2 @@ +2 +-1 diff --git a/tests/run/string-switch-defaults-null.scala b/tests/run/string-switch-defaults-null.scala new file mode 100644 index 000000000000..9fc4ce235a2d --- /dev/null +++ b/tests/run/string-switch-defaults-null.scala @@ -0,0 +1,16 @@ +import annotation.switch + +object Test { + def test(s: String): Int = { + (s : @switch) match { + case "1" => 0 + case null => -1 + case _ => s.toInt + } + } + + def main(args: Array[String]): Unit = { + println(test("2")) + println(test(null)) + } +} diff --git a/tests/run/string-switch.check b/tests/run/string-switch.check new file mode 100644 index 000000000000..7ab6b33ec0ae --- /dev/null +++ b/tests/run/string-switch.check @@ -0,0 +1,29 @@ +fido Success(dog) +garfield Success(cat) +wanda Success(fish) +henry Success(horse) +felix Failure(scala.MatchError: felix (of class java.lang.String)) +deuteronomy Success(cat) +===== +AaAa 2031744 Success(1) +BBBB 2031744 Success(2) +BBAa 2031744 Failure(scala.MatchError: BBAa (of class java.lang.String)) +cCCc 3015872 Success(3) +ddDd 3077408 Success(4) +EEee 2125120 Failure(scala.MatchError: EEee (of class java.lang.String)) +===== +A Success(()) +X Failure(scala.MatchError: X (of class java.lang.String)) +===== + Success(3) +null Success(2) +7 Failure(scala.MatchError: 7 (of class java.lang.String)) +===== +pig Success(1) +dog Success(2) +===== +Ea 2236 Success(1) +FB 2236 Success(2) +cC 3136 Success(3) +xx 3840 Success(4) +null 0 Success(4) diff --git a/tests/run/string-switch.scala b/tests/run/string-switch.scala new file mode 100644 index 000000000000..6a1522b416d9 --- /dev/null +++ b/tests/run/string-switch.scala @@ -0,0 +1,69 @@ +// scalac: -Werror +import annotation.switch +import util.Try + +object Test extends App { + + def species(name: String) = (name.toLowerCase : @switch) match { + case "fido" => "dog" + case "garfield" | "deuteronomy" => "cat" + case "wanda" => "fish" + case "henry" => "horse" + } + List("fido", "garfield", "wanda", "henry", "felix", "deuteronomy").foreach { n => println(s"$n ${Try(species(n))}") } + + println("=====") + + def collide(in: String) = (in : @switch) match { + case "AaAa" => 1 + case "BBBB" => 2 + case "cCCc" => 3 + case x if x == "ddDd" => 4 + } + List("AaAa", "BBBB", "BBAa", "cCCc", "ddDd", "EEee").foreach { s => + println(s"$s ${s.##} ${Try(collide(s))}") + } + + println("=====") + + def unitary(in: String) = (in : @switch) match { + case "A" => + case x => throw new MatchError(x) + } + List("A","X").foreach { s => + println(s"$s ${Try(unitary(s))}") + } + + println("=====") + + def nullFun(in: String) = (in : @switch) match { + case "1" => 1 + case null => 2 + case "" => 3 + } + List("", null, "7").foreach { s => + println(s"$s ${Try(nullFun(s))}") + } + + println("=====") + + def default(in: String) = (in : @switch) match { + case "pig" => 1 + case _ => 2 + } + List("pig","dog").foreach { s => + println(s"$s ${Try(default(s))}") + } + + println("=====") + + def onceOnly(in: Iterator[String]) = (in.next() : @switch) match { + case "Ea" => 1 + case "FB" => 2 //collision with above + case "cC" => 3 + case _ => 4 + } + List("Ea", "FB", "cC", "xx", null).foreach { s => + println(s"$s ${s.##} ${Try(onceOnly(Iterator(s)))}") + } +}