From 3cbc15e0dbd0d88faee3acdeeb993473cf32183d Mon Sep 17 00:00:00 2001 From: Derek Wickern Date: Tue, 14 May 2024 22:04:07 -0700 Subject: [PATCH 1/3] Emit switch bytecode when matching unions of a switchable type --- .../tools/dotc/transform/PatternMatcher.scala | 7 +- .../backend/jvm/DottyBytecodeTests.scala | 98 +++++++++++++++++++ 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index 0b8507f3b6c7..1e95ca1618b2 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -868,7 +868,7 @@ object PatternMatcher { (Nil, plan) :: Nil } - if (isSwitchableType(scrutinee.tpe.widen)) recur(plan) + if (isSwitchableType(scrutinee.tpe.widen.widenSingletons())) recur(plan) else Nil } @@ -889,8 +889,9 @@ object PatternMatcher { */ val (primScrutinee, scrutineeTpe) = - if (scrutinee.tpe.widen.isRef(defn.IntClass)) (scrutinee, defn.IntType) - else if (scrutinee.tpe.widen.isRef(defn.StringClass)) (scrutinee, defn.StringType) + val tpe = scrutinee.tpe.widen.widenSingletons() + if (tpe.isRef(defn.IntClass)) (scrutinee, defn.IntType) + else if (tpe.isRef(defn.StringClass)) (scrutinee, defn.StringType) else (scrutinee.select(nme.toInt), defn.IntType) def primLiteral(lit: Tree): Tree = diff --git a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala index f446913d7964..e4e485478804 100644 --- a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala @@ -158,6 +158,104 @@ class DottyBytecodeTests extends DottyBytecodeTest { } } + @Test def switchOnUnionOfInts = { + val source = + """ + |object Foo { + | def foo(x: 1 | 2 | 3 | 4 | 5) = x match { + | case 1 => println(3) + | case 2 | 3 => println(2) + | case 4 => println(1) + | case 5 => println(0) + | } + |} + """.stripMargin + + checkBCode(source) { dir => + val moduleIn = dir.lookupName("Foo$.class", directory = false) + val moduleNode = loadClassNode(moduleIn.input) + val methodNode = getMethod(moduleNode, "foo") + assert(verifySwitch(methodNode)) + } + } + + @Test def switchOnUnionOfStrings = { + val source = + """ + |object Foo { + | def foo(s: "one" | "two" | "three" | "four" | "five") = s match { + | case "one" => println(3) + | case "two" | "three" => println(2) + | case "four" | "five" => println(1) + | case _ => println(0) + | } + |} + """.stripMargin + + checkBCode(source) { dir => + val moduleIn = dir.lookupName("Foo$.class", directory = false) + val moduleNode = loadClassNode(moduleIn.input) + val methodNode = getMethod(moduleNode, "foo") + assert(verifySwitch(methodNode)) + } + } + + @Test def switchOnUnionOfIntSingletons = { + val source = + """ + |object Foo { + | final val One = 1 + | final val Two = 2 + | final val Three = 3 + | final val Four = 4 + | final val Five = 5 + | type Values = One.type | Two.type | Three.type | Four.type | Five.type + | + | def foo(s: Values) = s match { + | case One => println(3) + | case Two | Three => println(2) + | case Four => println(1) + | case Five => println(0) + | } + |} + """.stripMargin + + checkBCode(source) { dir => + val moduleIn = dir.lookupName("Foo$.class", directory = false) + val moduleNode = loadClassNode(moduleIn.input) + val methodNode = getMethod(moduleNode, "foo") + assert(verifySwitch(methodNode)) + } + } + + @Test def switchOnUnionOfStringSingletons = { + val source = + """ + |object Foo { + | final val One = "one" + | final val Two = "two" + | final val Three = "three" + | final val Four = "four" + | final val Five = "five" + | type Values = One.type | Two.type | Three.type | Four.type | Five.type + | + | def foo(s: Values) = s match { + | case One => println(3) + | case Two | Three => println(2) + | case Four => println(1) + | case Five => println(0) + | } + |} + """.stripMargin + + checkBCode(source) { dir => + val moduleIn = dir.lookupName("Foo$.class", directory = false) + val moduleNode = loadClassNode(moduleIn.input) + val methodNode = getMethod(moduleNode, "foo") + assert(verifySwitch(methodNode)) + } + } + @Test def matchWithDefaultNoThrowMatchError = { val source = """class Test { From 2c349c13ba69ffc13ced832c4e71e316e53b3d63 Mon Sep 17 00:00:00 2001 From: Derek Wickern Date: Thu, 16 May 2024 21:01:28 -0700 Subject: [PATCH 2/3] add test for union of Char --- .../backend/jvm/DottyBytecodeTests.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala index e4e485478804..f80336646dfd 100644 --- a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala @@ -200,6 +200,28 @@ class DottyBytecodeTests extends DottyBytecodeTest { } } + @Test def switchOnUnionOfChars = { + val source = + """ + |object Foo { + | def foo(ch: 'a' | 'b' | 'c' | 'd' | 'e'): Int = ch match { + | case 'a' => 1 + | case 'b' => 2 + | case 'c' => 3 + | case 'd' => 4 + | case 'e' => 5 + | } + |} + """.stripMargin + + checkBCode(source) { dir => + val moduleIn = dir.lookupName("Foo$.class", directory = false) + val moduleNode = loadClassNode(moduleIn.input) + val methodNode = getMethod(moduleNode, "foo") + assert(verifySwitch(methodNode)) + } + } + @Test def switchOnUnionOfIntSingletons = { val source = """ From 7279bf7cb0368b6abef68a8664e63505b22dc7ce Mon Sep 17 00:00:00 2001 From: Derek Wickern Date: Thu, 16 May 2024 21:03:57 -0700 Subject: [PATCH 3/3] replace widen with <:< --- .../tools/dotc/transform/PatternMatcher.scala | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index 1e95ca1618b2..1e16897081dd 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -814,11 +814,11 @@ object PatternMatcher { */ private def collectSwitchCases(scrutinee: Tree, plan: SeqPlan): List[(List[Tree], Plan)] = { def isSwitchableType(tpe: Type): Boolean = - (tpe isRef defn.IntClass) || - (tpe isRef defn.ByteClass) || - (tpe isRef defn.ShortClass) || - (tpe isRef defn.CharClass) || - (tpe isRef defn.StringClass) + (tpe <:< defn.IntType) || + (tpe <:< defn.ByteType) || + (tpe <:< defn.ShortType) || + (tpe <:< defn.CharType) || + (tpe <:< defn.StringType) val seen = mutable.Set[Any]() @@ -868,7 +868,7 @@ object PatternMatcher { (Nil, plan) :: Nil } - if (isSwitchableType(scrutinee.tpe.widen.widenSingletons())) recur(plan) + if (isSwitchableType(scrutinee.tpe)) recur(plan) else Nil } @@ -889,9 +889,8 @@ object PatternMatcher { */ val (primScrutinee, scrutineeTpe) = - val tpe = scrutinee.tpe.widen.widenSingletons() - if (tpe.isRef(defn.IntClass)) (scrutinee, defn.IntType) - else if (tpe.isRef(defn.StringClass)) (scrutinee, defn.StringType) + if (scrutinee.tpe <:< defn.IntType) (scrutinee, defn.IntType) + else if (scrutinee.tpe <:< defn.StringType) (scrutinee, defn.StringType) else (scrutinee.select(nme.toInt), defn.IntType) def primLiteral(lit: Tree): Tree =