Skip to content

Commit

Permalink
Explicitly call toSeq on wildcard-star patterns
Browse files Browse the repository at this point in the history
With name-based pattern matching, unapplySeq can return any type that
has an `isEmpty` and `get` methods. The object returned by `get` needs
to have `apply`, `length` or `lengthCompare` and `drop` methods.

This PR changes the type of `x @ _*` bound to `scala.Seq`. To support
that change, the object returned by `unapplySeq.get` is converted by
calling `.toSeq` or `drop(n)`.

This means there are two changes in the interface for name-based pattern
matching:
  - the object needs to define a `toSeq` method
  - the `drop` method needs to return a `scala.Seq`

The `unapplySeq` method defined in `collection.SeqFactory` now returns
a value class wrapper that delegates to the collection. `toSeq` no
longer exposes mutable collections.

`Array.unapplySeq` uses a similar value class wrapper.
  • Loading branch information
lrytz authored and SethTisue committed Aug 21, 2018
1 parent b8e609b commit 308ae2d
Show file tree
Hide file tree
Showing 15 changed files with 582 additions and 67 deletions.
13 changes: 1 addition & 12 deletions src/compiler/scala/tools/nsc/transform/patmat/MatchCodeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,7 @@ trait MatchCodeGen extends Interface {
def fun(arg: Symbol, body: Tree): Tree = Function(List(ValDef(arg)), body)
def tupleSel(binder: Symbol)(i: Int): Tree = (REF(binder) DOT nme.productAccessorName(i)) // make tree that accesses the i'th component of the tuple referenced by binder
def index(tgt: Tree)(i: Int): Tree = tgt APPLY (LIT(i))

// Right now this blindly calls drop on the result of the unapplySeq
// unless it verifiably has no drop method (this is the case in particular
// with Array.) You should not actually have to write a method called drop
// for name-based matching, but this was an expedient route for the basics.
def drop(tgt: Tree)(n: Int): Tree = {
def callDirect = fn(tgt, nme.drop, LIT(n))
def callRuntime = Apply(REF(currentRun.runDefinitions.traversableDropMethod), tgt :: LIT(n) :: Nil)
def needsRuntime = (tgt.tpe ne null) && (elementTypeFromDrop(tgt.tpe) == NoType)

if (needsRuntime) callRuntime else callDirect
}
def drop(tgt: Tree)(n: Int): Tree = fn(tgt, nme.drop, LIT(n))

// NOTE: checker must be the target of the ==, that's the patmat semantics for ya
def _equals(checker: Tree, binder: Symbol): Tree = checker MEMBER_== REF(binder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,22 +409,22 @@ trait MatchTranslation {
protected def expectedLength = elementArity
protected def lastIndexingBinder = nonStarArity - 1

private def productElemsToN(binder: Symbol, n: Int): List[Tree] = 1 to n map tupleSel(binder) toList
private def genTake(binder: Symbol, n: Int): List[Tree] = (0 until n).toList map (codegen index seqTree(binder))
private def genDrop(binder: Symbol, n: Int): List[Tree] = codegen.drop(seqTree(binder))(expectedLength) :: Nil
private def productElemsToN(binder: Symbol, n: Int): List[Tree] = if (n == 0) Nil else List.tabulate(n)(i => tupleSel(binder)(i + 1))
private def genTake(binder: Symbol, n: Int): List[Tree] = if (n == 0) Nil else List.tabulate(n)(codegen index seqTree(binder, forceImmutable = false))
private def genDrop(binder: Symbol, n: Int): List[Tree] = codegen.drop(seqTree(binder, forceImmutable = false))(n) :: Nil

// codegen.drop(seqTree(binder))(nbIndexingIndices)))).toList
protected def seqTree(binder: Symbol) = tupleSel(binder)(firstIndexingBinder + 1)
protected def tupleSel(binder: Symbol)(i: Int): Tree = codegen.tupleSel(binder)(i)
protected def seqTree(binder: Symbol, forceImmutable: Boolean) = tupleSel(binder)(firstIndexingBinder + 1)
protected def tupleSel(binder: Symbol)(i: Int): Tree = codegen.tupleSel(binder)(i)

// the trees that select the subpatterns on the extractor's result,
// referenced by `binder`
protected def subPatRefsSeq(binder: Symbol): List[Tree] = {
def lastTrees: List[Tree] = (
def lastTrees: List[Tree] = {
if (!isStar) Nil
else if (expectedLength == 0) seqTree(binder) :: Nil
else if (expectedLength == 0) seqTree(binder, forceImmutable = true) :: Nil
else genDrop(binder, expectedLength)
)
}
// this error-condition has already been checked by checkStarPatOK:
// if(isSeq) assert(firstIndexingBinder + nbIndexingIndices + (if(lastIsStar) 1 else 0) == totalArity, "(resultInMonad, ts, subPatTypes, subPats)= "+(resultInMonad, ts, subPatTypes, subPats))

Expand All @@ -440,23 +440,25 @@ trait MatchTranslation {

// the trees that select the subpatterns on the extractor's result, referenced by `binder`
// require (nbSubPats > 0 && (!lastIsStar || isSeq))
protected def subPatRefs(binder: Symbol): List[Tree] = (
protected def subPatRefs(binder: Symbol): List[Tree] = {
if (totalArity > 0 && isSeq) subPatRefsSeq(binder)
else productElemsToN(binder, totalArity)
)
}

private def compareInts(t1: Tree, t2: Tree) =
gen.mkMethodCall(termMember(ScalaPackage, "math"), TermName("signum"), Nil, (t1 INT_- t2) :: Nil)

protected def lengthGuard(binder: Symbol): Option[Tree] =
// no need to check unless it's an unapplySeq and the minimal length is non-trivially satisfied
checkedLength map { expectedLength =>
def lengthCompareSym = binder.info member nme.lengthCompare

// `binder.lengthCompare(expectedLength)`
// ...if binder has a lengthCompare method, otherwise
// `scala.math.signum(binder.length - expectedLength)`
def checkExpectedLength = lengthCompareSym match {
case NoSymbol => compareInts(Select(seqTree(binder), nme.length), LIT(expectedLength))
case lencmp => (seqTree(binder) DOT lencmp)(LIT(expectedLength))
case NoSymbol => compareInts(Select(seqTree(binder, forceImmutable = false), nme.length), LIT(expectedLength))
case lencmp => (seqTree(binder, forceImmutable = false) DOT lencmp)(LIT(expectedLength))
}

// the comparison to perform
Expand All @@ -467,7 +469,7 @@ trait MatchTranslation {
else _ INT_== _

// `if (binder != null && $checkExpectedLength [== | >=] 0) then else zero`
(seqTree(binder) ANY_!= NULL) AND compareOp(checkExpectedLength, ZERO)
(seqTree(binder, forceImmutable = false) ANY_!= NULL) AND compareOp(checkExpectedLength, ZERO)
}

def checkedLength: Option[Int] =
Expand Down Expand Up @@ -569,9 +571,13 @@ trait MatchTranslation {
extractorTreeMaker :: Nil
}

override protected def seqTree(binder: Symbol): Tree =
if (firstIndexingBinder == 0) REF(binder)
else super.seqTree(binder)
override protected def seqTree(binder: Symbol, forceImmutable: Boolean): Tree =
if (firstIndexingBinder == 0) {
val ref = REF(binder)
if (forceImmutable && !binder.tpe.typeSymbol.isNonBottomSubClass(SeqClass)) Select(ref, nme.toSeq)
else ref
}
else super.seqTree(binder, forceImmutable)

// the trees that select the subpatterns on the extractor's result, referenced by `binder`
// require (totalArity > 0 && (!lastIsStar || isSeq))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,9 @@ trait PatternExpansion {
else tps.map(_.substSym(List(unapplySelector), List(extractedBinder)))

val withoutStar = productTypes ::: List.fill(elementArity)(elementType)
replaceUnapplySelector(if (isStar) withoutStar :+ sequenceType else withoutStar)
replaceUnapplySelector(if (isStar) withoutStar :+ seqType(elementType) else withoutStar)
}

def lengthCompareSym = sequenceType member nme.lengthCompare

// rest is private
private val isUnapply = fun.symbol.name == nme.unapply
private val isUnapplySeq = fun.symbol.name == nme.unapplySeq
Expand Down Expand Up @@ -190,23 +188,19 @@ trait PatternExpansion {
}
else equivConstrParamTypes

private def notRepeated = (NoType, NoType, NoType)
private val (elementType, sequenceType, repeatedType) =
private def notRepeated = (NoType, NoType)
private val (elementType, repeatedType) =
// case class C() is deprecated, but still need to defend against equivConstrParamTypes.isEmpty
if (isUnapply || equivConstrParamTypes.isEmpty) notRepeated
else {
val lastParamTp = equivConstrParamTypes.last
if (isUnapplySeq) {
val elementTp =
elementTypeFromHead(lastParamTp) orElse
elementTypeFromApply(lastParamTp) orElse
definitions.elementType(ArrayClass, lastParamTp)

(elementTp, lastParamTp, scalaRepeatedType(elementTp))
val elementTp = elementTypeFromApply(lastParamTp)
(elementTp, scalaRepeatedType(elementTp))
} else {
definitions.elementType(RepeatedParamClass, lastParamTp) match {
case NoType => notRepeated
case elementTp => (elementTp, seqType(elementTp), lastParamTp)
case elementTp => (elementTp, lastParamTp)
}
}
}
Expand All @@ -228,7 +222,7 @@ trait PatternExpansion {
}

private def arityError(mismatch: String) = {
val isErroneous = (productTypes contains NoType) && !(isSeq && (sequenceType ne NoType))
val isErroneous = (productTypes contains NoType) && !(isSeq && (elementType ne NoType))

val offeringString = if (isErroneous) "<error>" else productTypes match {
case tps if isSeq => (tps.map(_.toString) :+ s"${elementType}*").mkString("(", ", ", ")")
Expand Down
12 changes: 10 additions & 2 deletions src/library/scala/Array.scala
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,16 @@ object Array {
* @param x the selector value
* @return sequence wrapped in a [[scala.Some]], if `x` is an Array, otherwise `None`
*/
def unapplySeq[T](x: Array[T]): Option[IndexedSeq[T]] =
Some(ArraySeq.unsafeWrapArray[T](x))
def unapplySeq[T](x: Array[T]): UnapplySeqWrapper[T] = new UnapplySeqWrapper(x)

final class UnapplySeqWrapper[T](private val a: Array[T]) extends AnyVal {
def isEmpty: Boolean = false
def get: UnapplySeqWrapper[T] = this
def lengthCompare(len: Int): Int = a.lengthCompare(len)
def apply(i: Int): T = a(i)
def drop(n: Int): scala.Seq[T] = ArraySeq.unsafeWrapArray(a.drop(n)) // clones the array, also if n == 0
def toSeq: scala.Seq[T] = a.toSeq // clones the array
}
}

/** Arrays are mutable, indexed collections of values. `Array[T]` is Scala's representation
Expand Down
29 changes: 20 additions & 9 deletions src/library/scala/collection/Factory.scala
Original file line number Diff line number Diff line change
Expand Up @@ -282,20 +282,30 @@ object IterableFactory {
/**
* @tparam CC Collection type constructor (e.g. `List`)
*/
trait SeqFactory[+CC[_]] extends IterableFactory[CC] {
def unapplySeq[A](x: CC[A] @uncheckedVariance): Some[CC[A]] = Some(x) //TODO is uncheckedVariance sound here?
trait SeqFactory[+CC[A] <: SeqOps[A, Seq, Seq[A]]] extends IterableFactory[CC] {
import SeqFactory.UnapplySeqWrapper
final def unapplySeq[A](x: CC[A] @uncheckedVariance): UnapplySeqWrapper[A] = new UnapplySeqWrapper(x) // TODO is uncheckedVariance sound here?
}

object SeqFactory {
@SerialVersionUID(3L)
class Delegate[CC[_]](delegate: SeqFactory[CC]) extends SeqFactory[CC] {
class Delegate[CC[A] <: SeqOps[A, Seq, Seq[A]]](delegate: SeqFactory[CC]) extends SeqFactory[CC] {
def empty[A]: CC[A] = delegate.empty
def from[E](it: IterableOnce[E]): CC[E] = delegate.from(it)
def newBuilder[A]: Builder[A, CC[A]] = delegate.newBuilder[A]
}

final class UnapplySeqWrapper[A](private val c: SeqOps[A, Seq, Seq[A]]) extends AnyVal {
def isEmpty: Boolean = false
def get: UnapplySeqWrapper[A] = this
def lengthCompare(len: Int): Int = c.lengthCompare(len)
def apply(i: Int): A = c(i)
def drop(n: Int): scala.Seq[A] = c.view.drop(n).toSeq
def toSeq: scala.Seq[A] = c.toSeq
}
}

trait StrictOptimizedSeqFactory[+CC[_]] extends SeqFactory[CC] {
trait StrictOptimizedSeqFactory[+CC[A] <: SeqOps[A, Seq, Seq[A]]] extends SeqFactory[CC] {

override def fill[A](n: Int)(elem: => A): CC[A] = {
val b = newBuilder[A]
Expand Down Expand Up @@ -645,23 +655,24 @@ object ClassTagIterableFactory {
/**
* @tparam CC Collection type constructor (e.g. `ArraySeq`)
*/
trait ClassTagSeqFactory[+CC[_]] extends ClassTagIterableFactory[CC] {
def unapplySeq[A](x: CC[A] @uncheckedVariance): Some[CC[A]] = Some(x) //TODO is uncheckedVariance sound here?
trait ClassTagSeqFactory[+CC[A] <: SeqOps[A, Seq, Seq[A]]] extends ClassTagIterableFactory[CC] {
import SeqFactory.UnapplySeqWrapper
final def unapplySeq[A](x: CC[A] @uncheckedVariance): UnapplySeqWrapper[A] = new UnapplySeqWrapper(x) // TODO is uncheckedVariance sound here?
}

object ClassTagSeqFactory {
@SerialVersionUID(3L)
class Delegate[CC[_]](delegate: ClassTagSeqFactory[CC])
class Delegate[CC[A] <: SeqOps[A, Seq, Seq[A]]](delegate: ClassTagSeqFactory[CC])
extends ClassTagIterableFactory.Delegate[CC](delegate) with ClassTagSeqFactory[CC]

/** A SeqFactory that uses ClassTag.Any as the evidence for every element type. This may or may not be
* sound depending on the use of the `ClassTag` by the collection implementation. */
@SerialVersionUID(3L)
class AnySeqDelegate[CC[_]](delegate: ClassTagSeqFactory[CC])
class AnySeqDelegate[CC[A] <: SeqOps[A, Seq, Seq[A]]](delegate: ClassTagSeqFactory[CC])
extends ClassTagIterableFactory.AnyIterableDelegate[CC](delegate) with SeqFactory[CC]
}

trait StrictOptimizedClassTagSeqFactory[+CC[_]] extends ClassTagSeqFactory[CC] {
trait StrictOptimizedClassTagSeqFactory[+CC[A] <: SeqOps[A, Seq, Seq[A]]] extends ClassTagSeqFactory[CC] {

override def fill[A : ClassTag](n: Int)(elem: => A): CC[A] = {
val b = newBuilder[A]
Expand Down
3 changes: 0 additions & 3 deletions src/reflect/scala/reflect/internal/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -928,9 +928,7 @@ trait Definitions extends api.StandardDefinitions {
// For name-based pattern matching, derive the "element type" (type argument of Option/Seq)
// from the relevant part of the signature of various members (get/head/apply/drop)
def elementTypeFromGet(tp: Type) = typeArgOfBaseTypeOr(tp, OptionClass)(resultOfMatchingMethod(tp, nme.get)())
def elementTypeFromHead(tp: Type) = typeArgOfBaseTypeOr(tp, SeqClass)(resultOfMatchingMethod(tp, nme.head)())
def elementTypeFromApply(tp: Type) = typeArgOfBaseTypeOr(tp, SeqClass)(resultOfMatchingMethod(tp, nme.apply)(IntTpe))
def elementTypeFromDrop(tp: Type) = typeArgOfBaseTypeOr(tp, SeqClass)(resultOfMatchingMethod(tp, nme.drop)(IntTpe))
def resultOfIsEmpty(tp: Type) = resultOfMatchingMethod(tp, nme.isEmpty)()

// scala/bug#8128 Still using the type argument of the base type at Seq/Option if this is an old-style (2.10 compatible)
Expand Down Expand Up @@ -1544,7 +1542,6 @@ trait Definitions extends api.StandardDefinitions {
lazy val arrayCloneMethod = getMemberMethod(ScalaRunTimeModule, nme.array_clone)
lazy val ensureAccessibleMethod = getMemberMethod(ScalaRunTimeModule, nme.ensureAccessible)
lazy val arrayClassMethod = getMemberMethod(ScalaRunTimeModule, nme.arrayClass)
lazy val traversableDropMethod = getMemberMethod(ScalaRunTimeModule, nme.drop)
lazy val wrapVarargsRefArrayMethod = getMemberMethod(getWrapVarargsArrayModule, nme.wrapRefArray)

lazy val GroupOfSpecializable = getMemberClass(SpecializableModule, tpnme.Group)
Expand Down
1 change: 1 addition & 0 deletions src/reflect/scala/reflect/internal/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,7 @@ trait StdNames {
val toArray: NameType = "toArray"
val toList: NameType = "toList"
val toObjectArray : NameType = "toObjectArray"
val toSeq: NameType = "toSeq"
val toStats: NameType = "toStats"
val TopScope: NameType = "TopScope"
val toString_ : NameType = "toString"
Expand Down
15 changes: 15 additions & 0 deletions test/files/neg/patmat-seq-neg.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
patmat-seq-neg.scala:15: error: error during expansion of this match (this is a scalac bug).
The underlying error was: type mismatch;
found : scala.collection.mutable.ArrayBuffer[Int]
required: Seq[Int]
def t3: Any = 2 match {
^
patmat-seq-neg.scala:18: error: error during expansion of this match (this is a scalac bug).
The underlying error was: value toSeq is not a member of Array[Int]
def t4: Any = 2 match {
^
patmat-seq-neg.scala:24: error: error during expansion of this match (this is a scalac bug).
The underlying error was: value drop is not a member of Array[Int]
def t6: Any = 2 match {
^
three errors found
27 changes: 27 additions & 0 deletions test/files/neg/patmat-seq-neg.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
object A {
def unapplySeq(a: Int) = Some(collection.mutable.ArrayBuffer(1,2,3))
}
object B {
def unapplySeq(a: Int) = Some(Array(1,2,3))
}

class T {
def t1: Any = 2 match {
case A(xs@_*) => xs // ok
}
def t2: Any = 2 match {
case A(x, y) => (x, y) // ok
}
def t3: Any = 2 match {
case A(x, xs@_*) => (x, xs) // type error with call to drop. found: ArrayBuffer, required: Seq.
}
def t4: Any = 2 match {
case B(xs@_*) => xs // error: toSeq is not a member of Array. no ArrayOps because adaptToMember is disabled after typer.
}
def t5: Any = 2 match {
case B(x, y) => (x, y) // ok
}
def t6: Any = 2 match {
case B(x, xs@_*) => (x, xs) // error: drop is not a member of Array
}
}
2 changes: 1 addition & 1 deletion test/files/run/macro-expand-unapply-a/Impls_Macros_1.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import scala.reflect.macros.whitebox.Context

object Helper {
def unapplySeq[T](x: List[T]): Option[Seq[T]] = List.unapplySeq[T](x)
def unapplySeq[T](x: List[T]): Option[Seq[T]] = Some(x)
}

object Macros {
Expand Down
Loading

0 comments on commit 308ae2d

Please sign in to comment.