Skip to content

Commit

Permalink
Emit switch bytecode when matching unions of a switchable type (#20411)
Browse files Browse the repository at this point in the history
Fixes #20410
  • Loading branch information
smarter committed May 22, 2024
2 parents 7d559ad + 7279bf7 commit e0c030c
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 8 deletions.
16 changes: 8 additions & 8 deletions compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -818,11 +818,11 @@ object PatternMatcher {
*/
private def collectSwitchCases(scrutinee: Tree, plan: SeqPlan): List[(List[Tree], Plan)] = {
def isSwitchableType(tpe: Type): Boolean =
(tpe isRef defn.IntClass) ||
(tpe isRef defn.ByteClass) ||
(tpe isRef defn.ShortClass) ||
(tpe isRef defn.CharClass) ||
(tpe isRef defn.StringClass)
(tpe <:< defn.IntType) ||
(tpe <:< defn.ByteType) ||
(tpe <:< defn.ShortType) ||
(tpe <:< defn.CharType) ||
(tpe <:< defn.StringType)

val seen = mutable.Set[Any]()

Expand Down Expand Up @@ -872,7 +872,7 @@ object PatternMatcher {
(Nil, plan) :: Nil
}

if (isSwitchableType(scrutinee.tpe.widen)) recur(plan)
if (isSwitchableType(scrutinee.tpe)) recur(plan)
else Nil
}

Expand All @@ -893,8 +893,8 @@ object PatternMatcher {
*/

val (primScrutinee, scrutineeTpe) =
if (scrutinee.tpe.widen.isRef(defn.IntClass)) (scrutinee, defn.IntType)
else if (scrutinee.tpe.widen.isRef(defn.StringClass)) (scrutinee, defn.StringType)
if (scrutinee.tpe <:< defn.IntType) (scrutinee, defn.IntType)
else if (scrutinee.tpe <:< defn.StringType) (scrutinee, defn.StringType)
else (scrutinee.select(nme.toInt), defn.IntType)

def primLiteral(lit: Tree): Tree =
Expand Down
120 changes: 120 additions & 0 deletions compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,126 @@ class DottyBytecodeTests extends DottyBytecodeTest {
}
}

@Test def switchOnUnionOfInts = {
val source =
"""
|object Foo {
| def foo(x: 1 | 2 | 3 | 4 | 5) = x match {
| case 1 => println(3)
| case 2 | 3 => println(2)
| case 4 => println(1)
| case 5 => println(0)
| }
|}
""".stripMargin

checkBCode(source) { dir =>
val moduleIn = dir.lookupName("Foo$.class", directory = false)
val moduleNode = loadClassNode(moduleIn.input)
val methodNode = getMethod(moduleNode, "foo")
assert(verifySwitch(methodNode))
}
}

@Test def switchOnUnionOfStrings = {
val source =
"""
|object Foo {
| def foo(s: "one" | "two" | "three" | "four" | "five") = s match {
| case "one" => println(3)
| case "two" | "three" => println(2)
| case "four" | "five" => println(1)
| case _ => println(0)
| }
|}
""".stripMargin

checkBCode(source) { dir =>
val moduleIn = dir.lookupName("Foo$.class", directory = false)
val moduleNode = loadClassNode(moduleIn.input)
val methodNode = getMethod(moduleNode, "foo")
assert(verifySwitch(methodNode))
}
}

@Test def switchOnUnionOfChars = {
val source =
"""
|object Foo {
| def foo(ch: 'a' | 'b' | 'c' | 'd' | 'e'): Int = ch match {
| case 'a' => 1
| case 'b' => 2
| case 'c' => 3
| case 'd' => 4
| case 'e' => 5
| }
|}
""".stripMargin

checkBCode(source) { dir =>
val moduleIn = dir.lookupName("Foo$.class", directory = false)
val moduleNode = loadClassNode(moduleIn.input)
val methodNode = getMethod(moduleNode, "foo")
assert(verifySwitch(methodNode))
}
}

@Test def switchOnUnionOfIntSingletons = {
val source =
"""
|object Foo {
| final val One = 1
| final val Two = 2
| final val Three = 3
| final val Four = 4
| final val Five = 5
| type Values = One.type | Two.type | Three.type | Four.type | Five.type
|
| def foo(s: Values) = s match {
| case One => println(3)
| case Two | Three => println(2)
| case Four => println(1)
| case Five => println(0)
| }
|}
""".stripMargin

checkBCode(source) { dir =>
val moduleIn = dir.lookupName("Foo$.class", directory = false)
val moduleNode = loadClassNode(moduleIn.input)
val methodNode = getMethod(moduleNode, "foo")
assert(verifySwitch(methodNode))
}
}

@Test def switchOnUnionOfStringSingletons = {
val source =
"""
|object Foo {
| final val One = "one"
| final val Two = "two"
| final val Three = "three"
| final val Four = "four"
| final val Five = "five"
| type Values = One.type | Two.type | Three.type | Four.type | Five.type
|
| def foo(s: Values) = s match {
| case One => println(3)
| case Two | Three => println(2)
| case Four => println(1)
| case Five => println(0)
| }
|}
""".stripMargin

checkBCode(source) { dir =>
val moduleIn = dir.lookupName("Foo$.class", directory = false)
val moduleNode = loadClassNode(moduleIn.input)
val methodNode = getMethod(moduleNode, "foo")
assert(verifySwitch(methodNode))
}
}

@Test def matchWithDefaultNoThrowMatchError = {
val source =
"""class Test {
Expand Down

0 comments on commit e0c030c

Please sign in to comment.