diff --git a/src/compiler/scala/tools/nsc/transform/CleanUp.scala b/src/compiler/scala/tools/nsc/transform/CleanUp.scala index 55cf697176da..0975452fa7bb 100644 --- a/src/compiler/scala/tools/nsc/transform/CleanUp.scala +++ b/src/compiler/scala/tools/nsc/transform/CleanUp.scala @@ -61,9 +61,6 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL { } private def mkTerm(prefix: String): TermName = unit.freshTermName(prefix) - //private val classConstantMeth = new HashMap[String, Symbol] - //private val symbolStaticFields = new HashMap[String, (Symbol, Tree, Tree)] - private var localTyper: analyzer.Typer = null private def typedWithPos(pos: Position)(tree: Tree) = @@ -383,6 +380,106 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL { } } + object StringsPattern { + def unapply(arg: Tree): Option[List[String]] = arg match { + case Literal(Constant(value: String)) => Some(value :: Nil) + case Literal(Constant(null)) => Some(null :: Nil) + case Alternative(alts) => traverseOpt(alts)(unapply).map(_.flatten) + case _ => None + } + } + + // transform scrutinee of all matches to ints + def transformSwitch(sw: Match): Tree = { import CODE._ + sw.selector.tpe match { + case IntTpe => sw // can switch directly on ints + case StringTpe => + // these assumptions about the shape of the tree are justified by the codegen in MatchOptimization + val Match(Typed(selTree: Ident, _), cases) = sw + val sel = selTree.symbol + val restpe = sw.tpe + val swPos = sw.pos.focus + + /* From this: + * string match { case "AaAa" => 1 case "BBBB" | "c" => 2 case _ => 3} + * Generate this: + * goto matchSuccess (string.## match { + * case 2031744 => + * if ("AaAa" equals string) goto match1 + * else if ("BBBB" equals string) goto match2 + * else goto matchFailure + * case 99 => + * if ("c" equals string) goto match2 + * else goto matchFailure + * case _ => goto matchFailure + * } + * match1: goto matchSuccess (1) + * match2: goto matchSuccess (2) + * matchFailure: goto matchSuccess (3) // would be throw new MatchError(string) if no default was given + * matchSuccess(res: Int): res + * This proliferation of labels is needed to handle alternative patterns, since multiple branches in the + * resulting switch may need to correspond to a single case body. + */ + + val stats = mutable.ListBuffer.empty[Tree] + var failureBody = Throw(New(definitions.MatchErrorClass.tpe_*, REF(sel))) : Tree + + // genbcode isn't thrilled about seeing labels with Unit arguments, so `success`'s type is one of + // `${sw.tpe} => ${sw.tpe}` or `() => Unit` depending. + val success = { + val lab = currentOwner.newLabel(unit.freshTermName("matchEnd"), swPos) + if (restpe =:= UnitTpe) { + lab.setInfo(MethodType(Nil, restpe)) + } else { + lab.setInfo(MethodType(lab.newValueParameter(nme.x_1).setInfo(restpe) :: Nil, restpe)) + } + } + def succeed(res: Tree): Tree = + if (restpe =:= UnitTpe) BLOCK(res, REF(success) APPLY Nil) else REF(success) APPLY res + + val failure = currentOwner.newLabel(unit.freshTermName("matchEnd"), swPos).setInfo(MethodType(Nil, restpe)) + def fail(): Tree = atPos(swPos) { Apply(REF(failure), Nil) } + + val newSel = atPos(sel.pos) { IF (sel OBJ_EQ NULL) THEN LIT(0) ELSE (Apply(REF(sel) DOT Object_hashCode, Nil)) } + val casesByHash = + cases.flatMap { + case cd@CaseDef(StringsPattern(strs), _, body) => + val jump = currentOwner.newLabel(unit.freshTermName("case"), swPos).setInfo(MethodType(Nil, restpe)) + stats += LabelDef(jump, Nil, succeed(body)) + strs.map((_, jump, cd.pat.pos)) + case cd@CaseDef(Ident(nme.WILDCARD), _, body) => + failureBody = succeed(body) + None + case cd => globalError(s"unhandled in switch: $cd"); None + }.groupBy(_._1.##) + val newCases = casesByHash.toList.sortBy(_._1).map { + case (hash, cases) => + val newBody = cases.foldLeft(fail()) { + case (next, (pat, jump, pos)) => + val comparison = if (pat == null) Object_eq else Object_equals + atPos(pos) { + IF(LIT(pat) DOT comparison APPLY REF(sel)) THEN (REF(jump) APPLY Nil) ELSE next + } + } + CaseDef(LIT(hash), EmptyTree, newBody) + } + + stats += LabelDef(failure, Nil, failureBody) + + stats += (if (restpe =:= UnitTpe) { + LabelDef(success, Nil, gen.mkLiteralUnit) + } else { + LabelDef(success, success.info.params.head :: Nil, REF(success.info.params.head)) + }) + + stats prepend Match(newSel, newCases :+ CaseDef(Ident(nme.WILDCARD), EmptyTree, fail())) + + val res = Block(stats.result : _*) + localTyper.typedPos(sw.pos)(res) + case _ => globalError(s"unhandled switch scrutinee type ${sw.selector.tpe}: $sw"); sw + } + } + override def transform(tree: Tree): Tree = tree match { case _: ClassDef if genBCode.codeGen.CodeGenImpl.isJavaEntryPoint(tree.symbol, currentUnit, settings.mainClass.valueSetByUser.map(_.toString)) => // collecting symbols for entry points here (as opposed to GenBCode where they are used) @@ -498,6 +595,9 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL { super.transform(localTyper.typedPos(tree.pos)(consed)) } + case switch: Match => + super.transform(transformSwitch(switch)) + case _ => super.transform(tree) } diff --git a/src/compiler/scala/tools/nsc/transform/Erasure.scala b/src/compiler/scala/tools/nsc/transform/Erasure.scala index 9e355f8a53f4..7690a734e03b 100644 --- a/src/compiler/scala/tools/nsc/transform/Erasure.scala +++ b/src/compiler/scala/tools/nsc/transform/Erasure.scala @@ -1285,7 +1285,7 @@ abstract class Erasure extends InfoTransform treeCopy.Template(tree, parents, noSelfType, addBridgesToTemplate(body, currentOwner)) case Match(selector, cases) => - Match(Typed(selector, TypeTree(selector.tpe)), cases) + treeCopy.Match(tree, Typed(selector, TypeTree(selector.tpe)), cases) case Literal(ct) => // We remove the original tree attachments in pre-erasure to free up memory diff --git a/src/compiler/scala/tools/nsc/transform/patmat/MatchOptimization.scala b/src/compiler/scala/tools/nsc/transform/patmat/MatchOptimization.scala index c410e8040316..5912586d645b 100644 --- a/src/compiler/scala/tools/nsc/transform/patmat/MatchOptimization.scala +++ b/src/compiler/scala/tools/nsc/transform/patmat/MatchOptimization.scala @@ -499,19 +499,33 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis { } } - class RegularSwitchMaker(scrutSym: Symbol, matchFailGenOverride: Option[Tree => Tree], val unchecked: Boolean) extends SwitchMaker { - val switchableTpe = Set(ByteTpe, ShortTpe, IntTpe, CharTpe) + class RegularSwitchMaker(scrutSym: Symbol, matchFailGenOverride: Option[Tree => Tree], val unchecked: Boolean) extends SwitchMaker { import CODE._ + val switchableTpe = Set(ByteTpe, ShortTpe, IntTpe, CharTpe, StringTpe) val alternativesSupported = true val canJump = true // Constant folding sets the type of a constant tree to `ConstantType(Constant(folded))` // The tree itself can be a literal, an ident, a selection, ... object SwitchablePattern { def unapply(pat: Tree): Option[Tree] = pat.tpe match { - case const: ConstantType if const.value.isIntRange => - Some(Literal(Constant(const.value.intValue))) // TODO: Java 7 allows strings in switches + case const: ConstantType => + if (const.value.isIntRange) + Some(LIT(const.value.intValue) setPos pat.pos) + else if (const.value.tag == StringTag) + Some(LIT(const.value.stringValue) setPos pat.pos) + else if (const.value.tag == NullTag) + Some(LIT(null) setPos pat.pos) + else None case _ => None }} + def scrutRef(scrut: Symbol): Tree = dealiasWiden(scrut.tpe) match { + case subInt if subInt =:= IntTpe => + REF(scrut) + case subInt if definitions.isNumericSubClass(subInt.typeSymbol, IntClass) => + REF(scrut) DOT nme.toInt + case _ => REF(scrut) + } + object SwitchableTreeMaker extends SwitchableTreeMakerExtractor { def unapply(x: TreeMaker): Option[Tree] = x match { case EqualityTestTreeMaker(_, SwitchablePattern(const), _) => Some(const) @@ -525,8 +539,8 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis { } def defaultSym: Symbol = scrutSym - def defaultBody: Tree = { import CODE._; matchFailGenOverride map (gen => gen(REF(scrutSym))) getOrElse Throw(MatchErrorClass.tpe, REF(scrutSym)) } - def defaultCase(scrutSym: Symbol = defaultSym, guard: Tree = EmptyTree, body: Tree = defaultBody): CaseDef = { import CODE._; atPos(body.pos) { + def defaultBody: Tree = { matchFailGenOverride map (gen => gen(REF(scrutSym))) getOrElse Throw(MatchErrorClass.tpe, REF(scrutSym)) } + def defaultCase(scrutSym: Symbol = defaultSym, guard: Tree = EmptyTree, body: Tree = defaultBody): CaseDef = { atPos(body.pos) { (DEFAULT IF guard) ==> body }} } @@ -539,12 +553,9 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis { if (caseDefsWithDefault.isEmpty) None // not worth emitting a switch. else { // match on scrutSym -- converted to an int if necessary -- not on scrut directly (to avoid duplicating scrut) - val scrutToInt: Tree = - if (scrutSym.tpe =:= IntTpe) REF(scrutSym) - else (REF(scrutSym) DOT (nme.toInt)) Some(BLOCK( ValDef(scrutSym, scrut), - Match(scrutToInt, caseDefsWithDefault) // a switch + Match(regularSwitchMaker.scrutRef(scrutSym), caseDefsWithDefault) // a switch )) } } else None diff --git a/test/files/jvm/string-switch/Switch_1.scala b/test/files/jvm/string-switch/Switch_1.scala new file mode 100644 index 000000000000..f2bd17cca262 --- /dev/null +++ b/test/files/jvm/string-switch/Switch_1.scala @@ -0,0 +1,7 @@ +import annotation.switch +class Switches { + val cond = true + def two = ("foo" : @switch) match { case "foo" => case "bar" => } + def guard = ("foo" : @switch) match { case "z" => case "y" => case x if cond => } + def colli = ("foo" : @switch) match { case "DB" => case "Ca" => } +} \ No newline at end of file diff --git a/test/files/jvm/string-switch/Test.scala b/test/files/jvm/string-switch/Test.scala new file mode 100644 index 000000000000..5fed9350f2b9 --- /dev/null +++ b/test/files/jvm/string-switch/Test.scala @@ -0,0 +1,16 @@ +import scala.tools.partest.BytecodeTest +import scala.tools.asm +import scala.collection.JavaConverters._ +import scala.PartialFunction.cond + +object Test extends BytecodeTest { + def show: Unit = { + val clasz = loadClassNode("Switches") + List("two", "guard", "colli") foreach { meth => + val mn = getMethod(clasz, meth) + assert(mn.instructions.iterator.asScala.exists(isSwitchInsn), meth) + } + } + def isSwitchInsn(insn: asm.tree.AbstractInsnNode) = + cond(insn.getOpcode) { case asm.Opcodes.LOOKUPSWITCH | asm.Opcodes.TABLESWITCH => true } +} diff --git a/test/files/run/string-switch-defaults-null.check b/test/files/run/string-switch-defaults-null.check new file mode 100644 index 000000000000..4bbcfcf56827 --- /dev/null +++ b/test/files/run/string-switch-defaults-null.check @@ -0,0 +1,2 @@ +2 +-1 diff --git a/test/files/run/string-switch-defaults-null.scala b/test/files/run/string-switch-defaults-null.scala new file mode 100644 index 000000000000..9fc4ce235a2d --- /dev/null +++ b/test/files/run/string-switch-defaults-null.scala @@ -0,0 +1,16 @@ +import annotation.switch + +object Test { + def test(s: String): Int = { + (s : @switch) match { + case "1" => 0 + case null => -1 + case _ => s.toInt + } + } + + def main(args: Array[String]): Unit = { + println(test("2")) + println(test(null)) + } +} diff --git a/test/files/run/string-switch-pos.check b/test/files/run/string-switch-pos.check new file mode 100644 index 000000000000..672c83c57642 --- /dev/null +++ b/test/files/run/string-switch-pos.check @@ -0,0 +1,73 @@ +[[syntax trees at end of patmat]] // newSource1.scala +[6]package [6] { + [6]class Switch extends [13][187]scala.AnyRef { + [187]def (): [13]Switch = [187]{ + [187][187][187]Switch.super.(); + [13]() + }; + [21]def switch([28]s: [31], [39]cond: [45]): [21]Int = [56]{ + [56]case val x1: [56]String = [56]s; + [56][56]x1 match { + [56]case [75]"AaAa" => [93]1 + [56]case [104]"asdf" => [122]2 + [133]case [133]"BbBb" => [133]if ([143]cond) + [151]3 + else + [180]4 + [56]case [56]_ => [56]throw [56][56][56]new [56]MatchError([56]x1) + } + } + } +} + +[[syntax trees at end of cleanup]] // newSource1.scala +[6]package [6] { + [6]class Switch extends [13][13]Object { + [21]def switch([28]s: [31], [39]cond: [45]): [21]Int = [56]{ + [56]case val x1: [56]String = [56]s; + [56]{ + [56][56]if ([56][56]x1.eq([56]null)) + [56]0 + else + [56][56]x1.hashCode() match { + [56]case [56]2031744 => [75]if ([75][75][75]"AaAa".equals([75]x1)) + [75][75]case1() + else + [56][56]matchEnd2() + [56]case [56]2062528 => [133]if ([133][133][133]"BbBb".equals([133]x1)) + [133][133]case3() + else + [56][56]matchEnd2() + [56]case [56]3003444 => [104]if ([104][104][104]"asdf".equals([104]x1)) + [104][104]case2() + else + [56][56]matchEnd2() + [56]case [56]_ => [56][56]matchEnd2() + }; + [56]case1(){ + [56][56]matchEnd1([93]1) + }; + [56]case2(){ + [56][56]matchEnd1([122]2) + }; + [56]case3(){ + [56][56]matchEnd1([133]if ([143]cond) + [151]3 + else + [180]4) + }; + [56]matchEnd2(){ + [56][56]matchEnd1([56]throw [56][56][56]new [56]MatchError([56]x1)) + }; + [56]matchEnd1(x$1: [NoPosition]Int){ + [56]x$1 + } + } + }; + [187]def (): [13]Switch = [187]{ + [187][187][187]Switch.super.(); + [13]() + } + } +} + diff --git a/test/files/run/string-switch-pos.scala b/test/files/run/string-switch-pos.scala new file mode 100644 index 000000000000..a75208046391 --- /dev/null +++ b/test/files/run/string-switch-pos.scala @@ -0,0 +1,18 @@ +import scala.tools.partest._ + +object Test extends DirectTest { + override def extraSettings: String = "-usejavacp -stop:cleanup -Vprint:patmat,cleanup -Vprint-pos" + + override def code = + """class Switch { + | def switch(s: String, cond: Boolean) = s match { + | case "AaAa" => 1 + | case "asdf" => 2 + | case "BbBb" if cond => 3 + | case "BbBb" => 4 + | } + |} + """.stripMargin.trim + + override def show(): Unit = Console.withErr(Console.out) { super.compile() } +} \ No newline at end of file diff --git a/test/files/run/string-switch.check b/test/files/run/string-switch.check new file mode 100644 index 000000000000..7ab6b33ec0ae --- /dev/null +++ b/test/files/run/string-switch.check @@ -0,0 +1,29 @@ +fido Success(dog) +garfield Success(cat) +wanda Success(fish) +henry Success(horse) +felix Failure(scala.MatchError: felix (of class java.lang.String)) +deuteronomy Success(cat) +===== +AaAa 2031744 Success(1) +BBBB 2031744 Success(2) +BBAa 2031744 Failure(scala.MatchError: BBAa (of class java.lang.String)) +cCCc 3015872 Success(3) +ddDd 3077408 Success(4) +EEee 2125120 Failure(scala.MatchError: EEee (of class java.lang.String)) +===== +A Success(()) +X Failure(scala.MatchError: X (of class java.lang.String)) +===== + Success(3) +null Success(2) +7 Failure(scala.MatchError: 7 (of class java.lang.String)) +===== +pig Success(1) +dog Success(2) +===== +Ea 2236 Success(1) +FB 2236 Success(2) +cC 3136 Success(3) +xx 3840 Success(4) +null 0 Success(4) diff --git a/test/files/run/string-switch.scala b/test/files/run/string-switch.scala new file mode 100644 index 000000000000..95cd1c9fd85b --- /dev/null +++ b/test/files/run/string-switch.scala @@ -0,0 +1,68 @@ +// scalac: -Werror +import annotation.switch +import util.Try + +object Test extends App { + + def species(name: String) = (name.toLowerCase : @switch) match { + case "fido" => "dog" + case "garfield" | "deuteronomy" => "cat" + case "wanda" => "fish" + case "henry" => "horse" + } + List("fido", "garfield", "wanda", "henry", "felix", "deuteronomy").foreach { n => println(s"$n ${Try(species(n))}") } + + println("=====") + + def collide(in: String) = (in : @switch) match { + case "AaAa" => 1 + case "BBBB" => 2 + case "cCCc" => 3 + case x if x == "ddDd" => 4 + } + List("AaAa", "BBBB", "BBAa", "cCCc", "ddDd", "EEee").foreach { s => + println(s"$s ${s.##} ${Try(collide(s))}") + } + + println("=====") + + def unitary(in: String) = (in : @switch) match { + case "A" => + } + List("A","X").foreach { s => + println(s"$s ${Try(unitary(s))}") + } + + println("=====") + + def nullFun(in: String) = (in : @switch) match { + case "1" => 1 + case null => 2 + case "" => 3 + } + List("", null, "7").foreach { s => + println(s"$s ${Try(nullFun(s))}") + } + + println("=====") + + def default(in: String) = (in : @switch) match { + case "pig" => 1 + case _ => 2 + } + List("pig","dog").foreach { s => + println(s"$s ${Try(default(s))}") + } + + println("=====") + + def onceOnly(in: Iterator[String]) = (in.next() : @switch) match { + case "Ea" => 1 + case "FB" => 2 //collision with above + case "cC" => 3 + case _ => 4 + } + List("Ea", "FB", "cC", "xx", null).foreach { s => + println(s"$s ${s.##} ${Try(onceOnly(Iterator(s)))}") + } +} \ No newline at end of file