diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 0483f699fc3d..328cedbad81f 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -422,6 +422,10 @@ class Definitions { def Seq_drop(implicit ctx: Context) = Seq_dropR.symbol lazy val Seq_lengthCompareR = SeqClass.requiredMethodRef(nme.lengthCompare) def Seq_lengthCompare(implicit ctx: Context) = Seq_lengthCompareR.symbol + lazy val Seq_lengthR = SeqClass.requiredMethodRef(nme.length) + def Seq_length(implicit ctx: Context) = Seq_lengthR.symbol + lazy val Seq_toSeqR = SeqClass.requiredMethodRef(nme.toSeq) + def Seq_toSeq(implicit ctx: Context) = Seq_toSeqR.symbol lazy val ArrayType: TypeRef = ctx.requiredClassRef("scala.Array") def ArrayClass(implicit ctx: Context) = ArrayType.symbol.asClass diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index 803d234e89dc..98f614ae5a18 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -254,7 +254,7 @@ object PatternMatcher { */ def matchElemsPlan(seqSym: Symbol, args: List[Tree], exact: Boolean, onSuccess: Plan) = { val selectors = args.indices.toList.map(idx => - ref(seqSym).select(nme.apply).appliedTo(Literal(Constant(idx)))) + ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx)))) TestPlan(LengthTest(args.length, exact), seqSym, seqSym.pos, matchArgsPlan(selectors, args, onSuccess)) } @@ -265,8 +265,13 @@ object PatternMatcher { def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = args.lastOption match { case Some(VarArgPattern(arg)) => val matchRemaining = - if (args.length == 1) - patternPlan(getResult, arg, onSuccess) + if (args.length == 1) { + val toSeq = ref(getResult) + .select(defn.Seq_toSeq.matchingMember(getResult.info)) + letAbstract(toSeq) { toSeqResult => + patternPlan(toSeqResult, arg, onSuccess) + } + } else { val dropped = ref(getResult) .select(defn.Seq_drop.matchingMember(getResult.info)) @@ -638,11 +643,18 @@ object PatternMatcher { case EqualTest(tree) => tree.equal(scrutinee) case LengthTest(len, exact) => - scrutinee - .select(defn.Seq_lengthCompare.matchingMember(scrutinee.tpe)) - .appliedTo(Literal(Constant(len))) - .select(if (exact) defn.Int_== else defn.Int_>=) - .appliedTo(Literal(Constant(0))) + val lengthCompareSym = defn.Seq_lengthCompare.matchingMember(scrutinee.tpe) + if (lengthCompareSym.exists) + scrutinee + .select(defn.Seq_lengthCompare.matchingMember(scrutinee.tpe)) + .appliedTo(Literal(Constant(len))) + .select(if (exact) defn.Int_== else defn.Int_>=) + .appliedTo(Literal(Constant(0))) + else // try length + scrutinee + .select(defn.Seq_length.matchingMember(scrutinee.tpe)) + .select(if (exact) defn.Int_== else defn.Int_>=) + .appliedTo(Literal(Constant(len))) case TypeTest(tpt) => val expectedTp = tpt.tpe diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index bce6fb03d7eb..86f1943ab74d 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -100,11 +100,43 @@ object Applications { Nil } + /** If `getType` is of the form: + * ``` + * { + * def lengthCompare(len: Int): Int // or, def length: Int + * def apply(i: Int): T = a(i) + * def drop(n: Int): scala.Seq[T] + * def toSeq: scala.Seq[T] + * } + * ``` + * returns `T`, otherwise NoType. + */ + def unapplySeqTypeElemTp(getTp: Type): Type = { + def lengthTp = ExprType(defn.IntType) + def lengthCompareTp = MethodType(List(defn.IntType), defn.IntType) + def applyTp(elemTp: Type) = MethodType(List(defn.IntType), elemTp) + def dropTp(elemTp: Type) = MethodType(List(defn.IntType), defn.SeqType.appliedTo(elemTp)) + def toSeqTp(elemTp: Type) = ExprType(defn.SeqType.appliedTo(elemTp)) + + // the result type of `def apply(i: Int): T` + val elemTp = getTp.member(nme.apply).suchThat(_.info <:< applyTp(WildcardType)).info.resultType + + def hasMethod(name: Name, tp: Type) = + getTp.member(name).suchThat(getTp.memberInfo(_) <:< tp).exists + + val isValid = + elemTp.exists && + (hasMethod(nme.lengthCompare, lengthCompareTp) || hasMethod(nme.length, lengthTp)) && + hasMethod(nme.drop, dropTp(elemTp)) && + hasMethod(nme.toSeq, toSeqTp(elemTp)) + + if (isValid) elemTp else NoType + } + if (unapplyName == nme.unapplySeq) { - if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil - else if (isGetMatch(unapplyResult, pos) && getTp.derivesFrom(defn.SeqClass)) { - val seqArg = getTp.elemType.hiBound - if (seqArg.exists) args.map(Function.const(seqArg)) + if (isGetMatch(unapplyResult, pos)) { + val elemTp = unapplySeqTypeElemTp(getTp) + if (elemTp.exists) args.map(Function.const(elemTp)) else fail } else fail diff --git a/docs/docs/reference/changed/pattern-matching.md b/docs/docs/reference/changed/pattern-matching.md index 40f5e94201a3..3285f0b04ea7 100644 --- a/docs/docs/reference/changed/pattern-matching.md +++ b/docs/docs/reference/changed/pattern-matching.md @@ -63,12 +63,23 @@ object FirstChars { ``` -## Seq Pattern +## Name-based Seq Pattern - Extractor defines `def unapplySeq(x: T): U` - `U` has (parameterless `def` or `val`) members `isEmpty: Boolean` and `get: S` -- `S <: Seq[V]` -- Pattern-matching on `N` pattern with types `V, V, ..., V`, where `N` is the runtime size of the `Seq`. +- `S` conforms to `X`, `T2` and `T3` conform to `T1` + +```Scala +type X = { + def lengthCompare(len: Int): Int // or, `def length: Int` + def apply(i: Int): T1 + def drop(n: Int): scala.Seq[T2] + def toSeq: scala.Seq[T3] +} +``` + +- Pattern-matching on _exactly_ `N` simple patterns with types `T1, T1, ..., T1`, where `N` is the runtime size of the sequence, or +- Pattern-matching on `>= N` simple patterns and _a vararg pattern_ (e.g., `xs: _*`) with types `T1, T1, ..., T1, Seq[T1]`, where `N` is the minimum size of the sequence. @@ -87,7 +98,7 @@ object CharList { ``` -## Name Based Pattern +## Name-based Pattern - Extractor defines `def unapply(x: T): U` - `U` has (parameterless `def` or `val`) members `isEmpty: Boolean` and `get: S` diff --git a/tests/neg/i4984.scala b/tests/neg/i4984.scala new file mode 100644 index 000000000000..fd35940e0731 --- /dev/null +++ b/tests/neg/i4984.scala @@ -0,0 +1,38 @@ +object Array2 { + def unapplySeq(x: Array[Int]): Data = new Data + class Data { + def isEmpty: Boolean = false + def get: Data = this + def lengthCompare(len: Int): Int = 0 + def apply(i: Int): Int = 3 + // drop return type, not conforming to apply's + def drop(n: Int): scala.Seq[String] = Seq("hello") + def toSeq: scala.Seq[Int] = Seq(6, 7) + } +} + +object Array3 { + def unapplySeq(x: Array[Int]): Data = new Data + class Data { + def isEmpty: Boolean = false + def get: Data = this + def lengthCompare(len: Int): Int = 0 + // missing apply + def drop(n: Int): scala.Seq[Int] = ??? + def toSeq: scala.Seq[Int] = ??? + } +} + +object Test { + def test(xs: Array[Int]): Int = xs match { + case Array2(x, y) => 1 // error + case Array2(x, y, xs: _*) => 2 // error + case Array2(xs: _*) => 3 // error + } + + def test2(xs: Array[Int]): Int = xs match { + case Array3(x, y) => 1 // error + case Array3(x, y, xs: _*) => 2 // error + case Array3(xs: _*) => 3 // error + } +} diff --git a/tests/pos/i4984.scala b/tests/pos/i4984.scala new file mode 100644 index 000000000000..f81f068113ed --- /dev/null +++ b/tests/pos/i4984.scala @@ -0,0 +1,26 @@ +object Array2 { + def unapplySeq[T](x: Array[T]): UnapplySeqWrapper[T] = new UnapplySeqWrapper(x) + + final class UnapplySeqWrapper[T](private val a: Array[T]) extends AnyVal { + def isEmpty: Boolean = false + def get: UnapplySeqWrapper[T] = this + def lengthCompare(len: Int): Int = a.lengthCompare(len) + def apply(i: Int): T = a(i) + def drop(n: Int): scala.Seq[T] = ??? + def toSeq: scala.Seq[T] = a.toSeq // clones the array + } +} + +class Test { + def test1(xs: Array[Int]): Int = xs match { + case Array2(x, y) => x + y + } + + def test2(xs: Array[Int]): Seq[Int] = xs match { + case Array2(x, y, xs:_*) => xs + } + + def test3(xs: Array[Int]): Seq[Int] = xs match { + case Array2(xs:_*) => xs + } +} diff --git a/tests/run/i4984b.scala b/tests/run/i4984b.scala new file mode 100644 index 000000000000..7f9beb62128c --- /dev/null +++ b/tests/run/i4984b.scala @@ -0,0 +1,32 @@ +object Array2 { + def unapplySeq(x: Array[Int]): Data = new Data + + final class Data { + def isEmpty: Boolean = false + def get: Data = this + def lengthCompare(len: Int): Int = 0 + def apply(i: Int): Int = 3 + def drop(n: Int): scala.Seq[Int] = Seq(2, 5) + def toSeq: scala.Seq[Int] = Seq(6, 7) + } +} + +object Test { + def test1(xs: Array[Int]): Int = xs match { + case Array2(x, y) => x + y + } + + def test2(xs: Array[Int]): Seq[Int] = xs match { + case Array2(x, y, xs:_*) => xs + } + + def test3(xs: Array[Int]): Seq[Int] = xs match { + case Array2(xs:_*) => xs + } + + def main(args: Array[String]): Unit = { + test1(Array(3, 5)) + test2(Array(3, 5)) + test3(Array(3, 5)) + } +} diff --git a/tests/run/i4984c.scala b/tests/run/i4984c.scala new file mode 100644 index 000000000000..cd23936ca209 --- /dev/null +++ b/tests/run/i4984c.scala @@ -0,0 +1,32 @@ +object Array2 { + def unapplySeq(x: Array[Int]): Data = new Data + + final class Data { + def isEmpty: Boolean = false + def get: Data = this + def length: Int = 2 + def apply(i: Int): Int = 3 + def drop(n: Int): scala.Seq[Int] = Seq(2, 5) + def toSeq: scala.Seq[Int] = Seq(6, 7) + } +} + +object Test { + def test1(xs: Array[Int]): Int = xs match { + case Array2(x, y) => x + y + } + + def test2(xs: Array[Int]): Seq[Int] = xs match { + case Array2(x, y, xs:_*) => xs + } + + def test3(xs: Array[Int]): Seq[Int] = xs match { + case Array2(xs:_*) => xs + } + + def main(args: Array[String]): Unit = { + test1(Array(3, 5)) + test2(Array(3, 5)) + test3(Array(3, 5)) + } +} diff --git a/tests/run/i4984d.scala b/tests/run/i4984d.scala new file mode 100644 index 000000000000..b53cc8585923 --- /dev/null +++ b/tests/run/i4984d.scala @@ -0,0 +1,35 @@ +object Array2 { + def unapplySeq(x: Array[Int]): Data1 = new Data1 + + class Data1 { + def isEmpty: Boolean = false + def get: Data2 = new Data2 + } + + class Data2 { + def apply(i: Int): Int = 3 + def drop(n: Int): scala.Seq[Int] = Seq(2, 5) + def toSeq: scala.Seq[Int] = Seq(6, 7) + def lengthCompare(len: Int): Int = 0 + } +} + +object Test { + def test1(xs: Array[Int]): Int = xs match { + case Array2(x, y) => x + y + } + + def test2(xs: Array[Int]): Seq[Int] = xs match { + case Array2(x, y, xs:_*) => xs + } + + def test3(xs: Array[Int]): Seq[Int] = xs match { + case Array2(xs:_*) => xs + } + + def main(args: Array[String]): Unit = { + test1(Array(3, 5)) + test2(Array(3, 5)) + test3(Array(3, 5)) + } +} diff --git a/tests/run/i4984e.scala b/tests/run/i4984e.scala new file mode 100644 index 000000000000..dfc4727f08af --- /dev/null +++ b/tests/run/i4984e.scala @@ -0,0 +1,36 @@ +object Array2 { + def unapplySeq(x: Array[Int]): Data = new Data + + final class Data { + def isEmpty: Boolean = false + def get: Data = this + def lengthCompare(len: Int): Int = 0 + def lengthCompare: Int = 0 + def apply(i: Int): Int = 3 + def apply(i: String): Int = 3 + def drop(n: Int): scala.Seq[Int] = Seq(2, 5) + def drop: scala.Seq[Int] = Seq(2, 5) + def toSeq: scala.Seq[Int] = Seq(6, 7) + def toSeq(x: Int): scala.Seq[Int] = Seq(6, 7) + } +} + +object Test { + def test1(xs: Array[Int]): Int = xs match { + case Array2(x, y) => x + y + } + + def test2(xs: Array[Int]): Seq[Int] = xs match { + case Array2(x, y, xs:_*) => xs + } + + def test3(xs: Array[Int]): Seq[Int] = xs match { + case Array2(xs:_*) => xs + } + + def main(args: Array[String]): Unit = { + test1(Array(3, 5)) + test2(Array(3, 5)) + test3(Array(3, 5)) + } +}