Skip to content

Commit

Permalink
Merge pull request #5078 from dotty-staging/fix-4984
Browse files Browse the repository at this point in the history
Fix #4984: support name-based unapplySeq
  • Loading branch information
odersky committed Sep 20, 2018
2 parents e34bb2d + 6a0fc8c commit a2a1112
Show file tree
Hide file tree
Showing 10 changed files with 274 additions and 16 deletions.
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,10 @@ class Definitions {
def Seq_drop(implicit ctx: Context) = Seq_dropR.symbol
lazy val Seq_lengthCompareR = SeqClass.requiredMethodRef(nme.lengthCompare)
def Seq_lengthCompare(implicit ctx: Context) = Seq_lengthCompareR.symbol
lazy val Seq_lengthR = SeqClass.requiredMethodRef(nme.length)
def Seq_length(implicit ctx: Context) = Seq_lengthR.symbol
lazy val Seq_toSeqR = SeqClass.requiredMethodRef(nme.toSeq)
def Seq_toSeq(implicit ctx: Context) = Seq_toSeqR.symbol

lazy val ArrayType: TypeRef = ctx.requiredClassRef("scala.Array")
def ArrayClass(implicit ctx: Context) = ArrayType.symbol.asClass
Expand Down
28 changes: 20 additions & 8 deletions compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ object PatternMatcher {
*/
def matchElemsPlan(seqSym: Symbol, args: List[Tree], exact: Boolean, onSuccess: Plan) = {
val selectors = args.indices.toList.map(idx =>
ref(seqSym).select(nme.apply).appliedTo(Literal(Constant(idx))))
ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx))))
TestPlan(LengthTest(args.length, exact), seqSym, seqSym.pos,
matchArgsPlan(selectors, args, onSuccess))
}
Expand All @@ -265,8 +265,13 @@ object PatternMatcher {
def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = args.lastOption match {
case Some(VarArgPattern(arg)) =>
val matchRemaining =
if (args.length == 1)
patternPlan(getResult, arg, onSuccess)
if (args.length == 1) {
val toSeq = ref(getResult)
.select(defn.Seq_toSeq.matchingMember(getResult.info))
letAbstract(toSeq) { toSeqResult =>
patternPlan(toSeqResult, arg, onSuccess)
}
}
else {
val dropped = ref(getResult)
.select(defn.Seq_drop.matchingMember(getResult.info))
Expand Down Expand Up @@ -638,11 +643,18 @@ object PatternMatcher {
case EqualTest(tree) =>
tree.equal(scrutinee)
case LengthTest(len, exact) =>
scrutinee
.select(defn.Seq_lengthCompare.matchingMember(scrutinee.tpe))
.appliedTo(Literal(Constant(len)))
.select(if (exact) defn.Int_== else defn.Int_>=)
.appliedTo(Literal(Constant(0)))
val lengthCompareSym = defn.Seq_lengthCompare.matchingMember(scrutinee.tpe)
if (lengthCompareSym.exists)
scrutinee
.select(defn.Seq_lengthCompare.matchingMember(scrutinee.tpe))
.appliedTo(Literal(Constant(len)))
.select(if (exact) defn.Int_== else defn.Int_>=)
.appliedTo(Literal(Constant(0)))
else // try length
scrutinee
.select(defn.Seq_length.matchingMember(scrutinee.tpe))
.select(if (exact) defn.Int_== else defn.Int_>=)
.appliedTo(Literal(Constant(len)))
case TypeTest(tpt) =>
val expectedTp = tpt.tpe

Expand Down
40 changes: 36 additions & 4 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,43 @@ object Applications {
Nil
}

/** If `getType` is of the form:
* ```
* {
* def lengthCompare(len: Int): Int // or, def length: Int
* def apply(i: Int): T = a(i)
* def drop(n: Int): scala.Seq[T]
* def toSeq: scala.Seq[T]
* }
* ```
* returns `T`, otherwise NoType.
*/
def unapplySeqTypeElemTp(getTp: Type): Type = {
def lengthTp = ExprType(defn.IntType)
def lengthCompareTp = MethodType(List(defn.IntType), defn.IntType)
def applyTp(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
def dropTp(elemTp: Type) = MethodType(List(defn.IntType), defn.SeqType.appliedTo(elemTp))
def toSeqTp(elemTp: Type) = ExprType(defn.SeqType.appliedTo(elemTp))

// the result type of `def apply(i: Int): T`
val elemTp = getTp.member(nme.apply).suchThat(_.info <:< applyTp(WildcardType)).info.resultType

def hasMethod(name: Name, tp: Type) =
getTp.member(name).suchThat(getTp.memberInfo(_) <:< tp).exists

val isValid =
elemTp.exists &&
(hasMethod(nme.lengthCompare, lengthCompareTp) || hasMethod(nme.length, lengthTp)) &&
hasMethod(nme.drop, dropTp(elemTp)) &&
hasMethod(nme.toSeq, toSeqTp(elemTp))

if (isValid) elemTp else NoType
}

if (unapplyName == nme.unapplySeq) {
if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil
else if (isGetMatch(unapplyResult, pos) && getTp.derivesFrom(defn.SeqClass)) {
val seqArg = getTp.elemType.hiBound
if (seqArg.exists) args.map(Function.const(seqArg))
if (isGetMatch(unapplyResult, pos)) {
val elemTp = unapplySeqTypeElemTp(getTp)
if (elemTp.exists) args.map(Function.const(elemTp))
else fail
}
else fail
Expand Down
19 changes: 15 additions & 4 deletions docs/docs/reference/changed/pattern-matching.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,23 @@ object FirstChars {
```


## Seq Pattern
## Name-based Seq Pattern

- Extractor defines `def unapplySeq(x: T): U`
- `U` has (parameterless `def` or `val`) members `isEmpty: Boolean` and `get: S`
- `S <: Seq[V]`
- Pattern-matching on `N` pattern with types `V, V, ..., V`, where `N` is the runtime size of the `Seq`.
- `S` conforms to `X`, `T2` and `T3` conform to `T1`

```Scala
type X = {
def lengthCompare(len: Int): Int // or, `def length: Int`
def apply(i: Int): T1
def drop(n: Int): scala.Seq[T2]
def toSeq: scala.Seq[T3]
}
```

- Pattern-matching on _exactly_ `N` simple patterns with types `T1, T1, ..., T1`, where `N` is the runtime size of the sequence, or
- Pattern-matching on `>= N` simple patterns and _a vararg pattern_ (e.g., `xs: _*`) with types `T1, T1, ..., T1, Seq[T1]`, where `N` is the minimum size of the sequence.

<!-- To be kept in sync with tests/new/patmat-spec.scala -->

Expand All @@ -87,7 +98,7 @@ object CharList {
```


## Name Based Pattern
## Name-based Pattern

- Extractor defines `def unapply(x: T): U`
- `U` has (parameterless `def` or `val`) members `isEmpty: Boolean` and `get: S`
Expand Down
38 changes: 38 additions & 0 deletions tests/neg/i4984.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
object Array2 {
def unapplySeq(x: Array[Int]): Data = new Data
class Data {
def isEmpty: Boolean = false
def get: Data = this
def lengthCompare(len: Int): Int = 0
def apply(i: Int): Int = 3
// drop return type, not conforming to apply's
def drop(n: Int): scala.Seq[String] = Seq("hello")
def toSeq: scala.Seq[Int] = Seq(6, 7)
}
}

object Array3 {
def unapplySeq(x: Array[Int]): Data = new Data
class Data {
def isEmpty: Boolean = false
def get: Data = this
def lengthCompare(len: Int): Int = 0
// missing apply
def drop(n: Int): scala.Seq[Int] = ???
def toSeq: scala.Seq[Int] = ???
}
}

object Test {
def test(xs: Array[Int]): Int = xs match {
case Array2(x, y) => 1 // error
case Array2(x, y, xs: _*) => 2 // error
case Array2(xs: _*) => 3 // error
}

def test2(xs: Array[Int]): Int = xs match {
case Array3(x, y) => 1 // error
case Array3(x, y, xs: _*) => 2 // error
case Array3(xs: _*) => 3 // error
}
}
26 changes: 26 additions & 0 deletions tests/pos/i4984.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
object Array2 {
def unapplySeq[T](x: Array[T]): UnapplySeqWrapper[T] = new UnapplySeqWrapper(x)

final class UnapplySeqWrapper[T](private val a: Array[T]) extends AnyVal {
def isEmpty: Boolean = false
def get: UnapplySeqWrapper[T] = this
def lengthCompare(len: Int): Int = a.lengthCompare(len)
def apply(i: Int): T = a(i)
def drop(n: Int): scala.Seq[T] = ???
def toSeq: scala.Seq[T] = a.toSeq // clones the array
}
}

class Test {
def test1(xs: Array[Int]): Int = xs match {
case Array2(x, y) => x + y
}

def test2(xs: Array[Int]): Seq[Int] = xs match {
case Array2(x, y, xs:_*) => xs
}

def test3(xs: Array[Int]): Seq[Int] = xs match {
case Array2(xs:_*) => xs
}
}
32 changes: 32 additions & 0 deletions tests/run/i4984b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
object Array2 {
def unapplySeq(x: Array[Int]): Data = new Data

final class Data {
def isEmpty: Boolean = false
def get: Data = this
def lengthCompare(len: Int): Int = 0
def apply(i: Int): Int = 3
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
def toSeq: scala.Seq[Int] = Seq(6, 7)
}
}

object Test {
def test1(xs: Array[Int]): Int = xs match {
case Array2(x, y) => x + y
}

def test2(xs: Array[Int]): Seq[Int] = xs match {
case Array2(x, y, xs:_*) => xs
}

def test3(xs: Array[Int]): Seq[Int] = xs match {
case Array2(xs:_*) => xs
}

def main(args: Array[String]): Unit = {
test1(Array(3, 5))
test2(Array(3, 5))
test3(Array(3, 5))
}
}
32 changes: 32 additions & 0 deletions tests/run/i4984c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
object Array2 {
def unapplySeq(x: Array[Int]): Data = new Data

final class Data {
def isEmpty: Boolean = false
def get: Data = this
def length: Int = 2
def apply(i: Int): Int = 3
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
def toSeq: scala.Seq[Int] = Seq(6, 7)
}
}

object Test {
def test1(xs: Array[Int]): Int = xs match {
case Array2(x, y) => x + y
}

def test2(xs: Array[Int]): Seq[Int] = xs match {
case Array2(x, y, xs:_*) => xs
}

def test3(xs: Array[Int]): Seq[Int] = xs match {
case Array2(xs:_*) => xs
}

def main(args: Array[String]): Unit = {
test1(Array(3, 5))
test2(Array(3, 5))
test3(Array(3, 5))
}
}
35 changes: 35 additions & 0 deletions tests/run/i4984d.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
object Array2 {
def unapplySeq(x: Array[Int]): Data1 = new Data1

class Data1 {
def isEmpty: Boolean = false
def get: Data2 = new Data2
}

class Data2 {
def apply(i: Int): Int = 3
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
def toSeq: scala.Seq[Int] = Seq(6, 7)
def lengthCompare(len: Int): Int = 0
}
}

object Test {
def test1(xs: Array[Int]): Int = xs match {
case Array2(x, y) => x + y
}

def test2(xs: Array[Int]): Seq[Int] = xs match {
case Array2(x, y, xs:_*) => xs
}

def test3(xs: Array[Int]): Seq[Int] = xs match {
case Array2(xs:_*) => xs
}

def main(args: Array[String]): Unit = {
test1(Array(3, 5))
test2(Array(3, 5))
test3(Array(3, 5))
}
}
36 changes: 36 additions & 0 deletions tests/run/i4984e.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
object Array2 {
def unapplySeq(x: Array[Int]): Data = new Data

final class Data {
def isEmpty: Boolean = false
def get: Data = this
def lengthCompare(len: Int): Int = 0
def lengthCompare: Int = 0
def apply(i: Int): Int = 3
def apply(i: String): Int = 3
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
def drop: scala.Seq[Int] = Seq(2, 5)
def toSeq: scala.Seq[Int] = Seq(6, 7)
def toSeq(x: Int): scala.Seq[Int] = Seq(6, 7)
}
}

object Test {
def test1(xs: Array[Int]): Int = xs match {
case Array2(x, y) => x + y
}

def test2(xs: Array[Int]): Seq[Int] = xs match {
case Array2(x, y, xs:_*) => xs
}

def test3(xs: Array[Int]): Seq[Int] = xs match {
case Array2(xs:_*) => xs
}

def main(args: Array[String]): Unit = {
test1(Array(3, 5))
test2(Array(3, 5))
test3(Array(3, 5))
}
}

0 comments on commit a2a1112

Please sign in to comment.