Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #4984: support name-based unapplySeq #5078

Merged
merged 10 commits into from
Sep 20, 2018
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,10 @@ class Definitions {
def Seq_drop(implicit ctx: Context) = Seq_dropR.symbol
lazy val Seq_lengthCompareR = SeqClass.requiredMethodRef(nme.lengthCompare)
def Seq_lengthCompare(implicit ctx: Context) = Seq_lengthCompareR.symbol
lazy val Seq_lengthR = SeqClass.requiredMethodRef(nme.length)
def Seq_length(implicit ctx: Context) = Seq_lengthR.symbol
lazy val Seq_toSeqR = SeqClass.requiredMethodRef(nme.toSeq)
def Seq_toSeq(implicit ctx: Context) = Seq_toSeqR.symbol

lazy val ArrayType: TypeRef = ctx.requiredClassRef("scala.Array")
def ArrayClass(implicit ctx: Context) = ArrayType.symbol.asClass
Expand Down
28 changes: 20 additions & 8 deletions compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ object PatternMatcher {
*/
def matchElemsPlan(seqSym: Symbol, args: List[Tree], exact: Boolean, onSuccess: Plan) = {
val selectors = args.indices.toList.map(idx =>
ref(seqSym).select(nme.apply).appliedTo(Literal(Constant(idx))))
ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx))))
TestPlan(LengthTest(args.length, exact), seqSym, seqSym.pos,
matchArgsPlan(selectors, args, onSuccess))
}
Expand All @@ -265,8 +265,13 @@ object PatternMatcher {
def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = args.lastOption match {
case Some(VarArgPattern(arg)) =>
val matchRemaining =
if (args.length == 1)
patternPlan(getResult, arg, onSuccess)
if (args.length == 1) {
val toSeq = ref(getResult)
.select(defn.Seq_toSeq.matchingMember(getResult.info))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not simply: .select(nme.toSeq)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is defensive here to support possible overloading of toSeq.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add tests for this then. Would it work to do .select(nme.toSeq).appliedToNone()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My hunch is that it will crash the compiler, as appliedToNone will not do overloading resolution.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But there is no need right? You cannot have two overloads of a method that takes no parameter

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind, your solution seems better. matchingMember is already used in multiple places in this file (I suppose for reasons)

letAbstract(toSeq) { toSeqResult =>
patternPlan(toSeqResult, arg, onSuccess)
}
}
else {
val dropped = ref(getResult)
.select(defn.Seq_drop.matchingMember(getResult.info))
Expand Down Expand Up @@ -638,11 +643,18 @@ object PatternMatcher {
case EqualTest(tree) =>
tree.equal(scrutinee)
case LengthTest(len, exact) =>
scrutinee
.select(defn.Seq_lengthCompare.matchingMember(scrutinee.tpe))
.appliedTo(Literal(Constant(len)))
.select(if (exact) defn.Int_== else defn.Int_>=)
.appliedTo(Literal(Constant(0)))
val lengthCompareSym = defn.Seq_lengthCompare.matchingMember(scrutinee.tpe)
if (lengthCompareSym.exists)
scrutinee
.select(defn.Seq_lengthCompare.matchingMember(scrutinee.tpe))
.appliedTo(Literal(Constant(len)))
.select(if (exact) defn.Int_== else defn.Int_>=)
.appliedTo(Literal(Constant(0)))
else // try length
scrutinee
.select(defn.Seq_length.matchingMember(scrutinee.tpe))
.select(if (exact) defn.Int_== else defn.Int_>=)
.appliedTo(Literal(Constant(len)))
case TypeTest(tpt) =>
val expectedTp = tpt.tpe

Expand Down
40 changes: 36 additions & 4 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,43 @@ object Applications {
Nil
}

/** If `getType` is of the form:
* ```
* {
* def lengthCompare(len: Int): Int // or, def length: Int
* def apply(i: Int): T = a(i)
* def drop(n: Int): scala.Seq[T]
* def toSeq: scala.Seq[T]
* }
* ```
* returns `T`, otherwise NoType.
*/
def unapplySeqTypeElemTp(getTp: Type): Type = {
def lengthTp = ExprType(defn.IntType)
def lengthCompareTp = MethodType(List(defn.IntType), defn.IntType)
def applyTp(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
def dropTp(elemTp: Type) = MethodType(List(defn.IntType), defn.SeqType.appliedTo(elemTp))
def toSeqTp(elemTp: Type) = ExprType(defn.SeqType.appliedTo(elemTp))

// the result type of `def apply(i: Int): T`
val elemTp = getTp.member(nme.apply).suchThat(_.info <:< applyTp(WildcardType)).info.resultType

def hasMethod(name: Name, tp: Type) =
getTp.member(name).suchThat(getTp.memberInfo(_) <:< tp).exists

val isValid =
elemTp.exists &&
(hasMethod(nme.lengthCompare, lengthCompareTp) || hasMethod(nme.length, lengthTp)) &&
hasMethod(nme.drop, dropTp(elemTp)) &&
hasMethod(nme.toSeq, toSeqTp(elemTp))

if (isValid) elemTp else NoType
}

if (unapplyName == nme.unapplySeq) {
if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil
else if (isGetMatch(unapplyResult, pos) && getTp.derivesFrom(defn.SeqClass)) {
val seqArg = getTp.elemType.hiBound
if (seqArg.exists) args.map(Function.const(seqArg))
if (isGetMatch(unapplyResult, pos)) {
val elemTp = unapplySeqTypeElemTp(getTp)
if (elemTp.exists) args.map(Function.const(elemTp))
else fail
}
else fail
Expand Down
19 changes: 15 additions & 4 deletions docs/docs/reference/changed/pattern-matching.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,23 @@ object FirstChars {
```


## Seq Pattern
## Name-based Seq Pattern

- Extractor defines `def unapplySeq(x: T): U`
- `U` has (parameterless `def` or `val`) members `isEmpty: Boolean` and `get: S`
- `S <: Seq[V]`
- Pattern-matching on `N` pattern with types `V, V, ..., V`, where `N` is the runtime size of the `Seq`.
- `S` conforms to `X`, `T2` and `T3` conform to `T1`

```Scala
type X = {
def lengthCompare(len: Int): Int // or, `def length: Int`
def apply(i: Int): T1
def drop(n: Int): scala.Seq[T2]
def toSeq: scala.Seq[T3]
}
```

- Pattern-matching on _exactly_ `N` simple patterns with types `T1, T1, ..., T1`, where `N` is the runtime size of the sequence, or
- Pattern-matching on `>= N` simple patterns and _a vararg pattern_ (e.g., `xs: _*`) with types `T1, T1, ..., T1, Seq[T1]`, where `N` is the minimum size of the sequence.

<!-- To be kept in sync with tests/new/patmat-spec.scala -->

Expand All @@ -87,7 +98,7 @@ object CharList {
```


## Name Based Pattern
## Name-based Pattern

- Extractor defines `def unapply(x: T): U`
- `U` has (parameterless `def` or `val`) members `isEmpty: Boolean` and `get: S`
Expand Down
38 changes: 38 additions & 0 deletions tests/neg/i4984.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
object Array2 {
def unapplySeq(x: Array[Int]): Data = new Data
class Data {
def isEmpty: Boolean = false
def get: Data = this
def lengthCompare(len: Int): Int = 0
def apply(i: Int): Int = 3
// drop return type, not conforming to apply's
def drop(n: Int): scala.Seq[String] = Seq("hello")
def toSeq: scala.Seq[Int] = Seq(6, 7)
}
}

object Array3 {
def unapplySeq(x: Array[Int]): Data = new Data
class Data {
def isEmpty: Boolean = false
def get: Data = this
def lengthCompare(len: Int): Int = 0
// missing apply
def drop(n: Int): scala.Seq[Int] = ???
def toSeq: scala.Seq[Int] = ???
}
}

object Test {
def test(xs: Array[Int]): Int = xs match {
case Array2(x, y) => 1 // error
case Array2(x, y, xs: _*) => 2 // error
case Array2(xs: _*) => 3 // error
}

def test2(xs: Array[Int]): Int = xs match {
case Array3(x, y) => 1 // error
case Array3(x, y, xs: _*) => 2 // error
case Array3(xs: _*) => 3 // error
}
}
26 changes: 26 additions & 0 deletions tests/pos/i4984.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
object Array2 {
def unapplySeq[T](x: Array[T]): UnapplySeqWrapper[T] = new UnapplySeqWrapper(x)

final class UnapplySeqWrapper[T](private val a: Array[T]) extends AnyVal {
def isEmpty: Boolean = false
def get: UnapplySeqWrapper[T] = this
def lengthCompare(len: Int): Int = a.lengthCompare(len)
def apply(i: Int): T = a(i)
def drop(n: Int): scala.Seq[T] = ???
def toSeq: scala.Seq[T] = a.toSeq // clones the array
}
}

class Test {
def test1(xs: Array[Int]): Int = xs match {
case Array2(x, y) => x + y
}

def test2(xs: Array[Int]): Seq[Int] = xs match {
case Array2(x, y, xs:_*) => xs
}

def test3(xs: Array[Int]): Seq[Int] = xs match {
case Array2(xs:_*) => xs
}
}
32 changes: 32 additions & 0 deletions tests/run/i4984b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
object Array2 {
def unapplySeq(x: Array[Int]): Data = new Data

final class Data {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also a add test with:

class Data1 {
    def isEmpty: Boolean = false
    def get: Data2 = new Data2
    def lengthCompare(len: Int): Int = 0
  }
}

class Data2 {
    def apply(i: Int): Int = 3
    def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
    def toSeq: scala.Seq[Int] = Seq(6, 7)
}

def isEmpty: Boolean = false
def get: Data = this
def lengthCompare(len: Int): Int = 0
def apply(i: Int): Int = 3
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
def toSeq: scala.Seq[Int] = Seq(6, 7)
}
}

object Test {
def test1(xs: Array[Int]): Int = xs match {
case Array2(x, y) => x + y
}

def test2(xs: Array[Int]): Seq[Int] = xs match {
case Array2(x, y, xs:_*) => xs
}

def test3(xs: Array[Int]): Seq[Int] = xs match {
case Array2(xs:_*) => xs
}

def main(args: Array[String]): Unit = {
test1(Array(3, 5))
test2(Array(3, 5))
test3(Array(3, 5))
}
}
32 changes: 32 additions & 0 deletions tests/run/i4984c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
object Array2 {
def unapplySeq(x: Array[Int]): Data = new Data

final class Data {
def isEmpty: Boolean = false
def get: Data = this
def length: Int = 2
def apply(i: Int): Int = 3
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
def toSeq: scala.Seq[Int] = Seq(6, 7)
}
}

object Test {
def test1(xs: Array[Int]): Int = xs match {
case Array2(x, y) => x + y
}

def test2(xs: Array[Int]): Seq[Int] = xs match {
case Array2(x, y, xs:_*) => xs
}

def test3(xs: Array[Int]): Seq[Int] = xs match {
case Array2(xs:_*) => xs
}

def main(args: Array[String]): Unit = {
test1(Array(3, 5))
test2(Array(3, 5))
test3(Array(3, 5))
}
}
35 changes: 35 additions & 0 deletions tests/run/i4984d.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
object Array2 {
def unapplySeq(x: Array[Int]): Data1 = new Data1

class Data1 {
def isEmpty: Boolean = false
def get: Data2 = new Data2
}

class Data2 {
def apply(i: Int): Int = 3
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
def toSeq: scala.Seq[Int] = Seq(6, 7)
def lengthCompare(len: Int): Int = 0
}
}

object Test {
def test1(xs: Array[Int]): Int = xs match {
case Array2(x, y) => x + y
}

def test2(xs: Array[Int]): Seq[Int] = xs match {
case Array2(x, y, xs:_*) => xs
}

def test3(xs: Array[Int]): Seq[Int] = xs match {
case Array2(xs:_*) => xs
}

def main(args: Array[String]): Unit = {
test1(Array(3, 5))
test2(Array(3, 5))
test3(Array(3, 5))
}
}
36 changes: 36 additions & 0 deletions tests/run/i4984e.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
object Array2 {
def unapplySeq(x: Array[Int]): Data = new Data

final class Data {
def isEmpty: Boolean = false
def get: Data = this
def lengthCompare(len: Int): Int = 0
def lengthCompare: Int = 0
def apply(i: Int): Int = 3
def apply(i: String): Int = 3
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
def drop: scala.Seq[Int] = Seq(2, 5)
def toSeq: scala.Seq[Int] = Seq(6, 7)
def toSeq(x: Int): scala.Seq[Int] = Seq(6, 7)
}
}

object Test {
def test1(xs: Array[Int]): Int = xs match {
case Array2(x, y) => x + y
}

def test2(xs: Array[Int]): Seq[Int] = xs match {
case Array2(x, y, xs:_*) => xs
}

def test3(xs: Array[Int]): Seq[Int] = xs match {
case Array2(xs:_*) => xs
}

def main(args: Array[String]): Unit = {
test1(Array(3, 5))
test2(Array(3, 5))
test3(Array(3, 5))
}
}