Skip to content

Commit

Permalink
SI-7897, SI-6675 improves name-based patmat
Browse files Browse the repository at this point in the history
This emerges from a recent attempt to eliminate pattern matcher
related duplication and to bake the scalac-independent logic
out of it. I had in mind something a lot cleaner, but it was
a whole lot of work to get it here and I can take it no further.

Key file to admire is PatternExpander.scala, which should
provide a basis for some separation of concerns.

The bugs addressed are a CCE involving Tuple1 and an imprecise
warning regarding multiple pattern crushing.

Editorial: auto-tupling unapply results was a terrible idea which
should never have escaped from the crib. It is tantamount to
purposely throwing type safety down the toilet in the very place
where people need type safety the most. See SI-6111 and SI-6675 for
some other comments.
  • Loading branch information
paulp committed Dec 16, 2013
1 parent dbe7a36 commit 11bfa25
Show file tree
Hide file tree
Showing 23 changed files with 723 additions and 303 deletions.
11 changes: 5 additions & 6 deletions src/compiler/scala/tools/nsc/transform/UnCurry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -457,12 +457,11 @@ abstract class UnCurry extends InfoTransform
else
super.transform(tree)
case UnApply(fn, args) =>
val fn1 = transform(fn)
val args1 = transformTrees(fn.symbol.name match {
case nme.unapply => args
case nme.unapplySeq => transformArgs(tree.pos, fn.symbol, args, localTyper.expectedPatternTypes(fn, args))
case _ => sys.error("internal error: UnApply node has wrong symbol")
})
val fn1 = transform(fn)
val args1 = fn.symbol.name match {
case nme.unapplySeq => transformArgs(tree.pos, fn.symbol, args, patmat.alignPatterns(tree).expectedTypes)
case _ => args
}
treeCopy.UnApply(tree, fn1, args1)

case Apply(fn, args) =>
Expand Down
165 changes: 56 additions & 109 deletions src/compiler/scala/tools/nsc/transform/patmat/MatchTranslation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,30 @@ trait MatchTranslation {
trait MatchTranslator extends TreeMakers with TreeMakerWarnings {
import typer.context

/** A conservative approximation of which patterns do not discern anything.
* They are discarded during the translation.
*/
object WildcardPattern {
def unapply(pat: Tree): Boolean = pat match {
case Bind(nme.WILDCARD, WildcardPattern()) => true // don't skip when binding an interesting symbol!
case Star(WildcardPattern()) => true
case x: Ident => treeInfo.isVarPattern(x)
case Alternative(ps) => ps forall unapply
case EmptyTree => true
case _ => false
}
}

object PatternBoundToUnderscore {
def unapply(pat: Tree): Boolean = pat match {
case Bind(nme.WILDCARD, _) => true // don't skip when binding an interesting symbol!
case Ident(nme.WILDCARD) => true
case Alternative(ps) => ps forall unapply
case Typed(PatternBoundToUnderscore(), _) => true
case _ => false
}
}

object SymbolBound {
def unapply(tree: Tree): Option[(Symbol, Tree)] = tree match {
case Bind(_, expr) if hasSym(tree) => Some(tree.symbol -> expr)
Expand Down Expand Up @@ -86,9 +110,10 @@ trait MatchTranslation {

// example check: List[Int] <:< ::[Int]
private def extractorStep(): TranslationStep = {
import extractor.{ paramType, treeMaker }
if (!extractor.isTyped)
ErrorUtils.issueNormalTypeError(tree, "Could not typecheck extractor call: "+ extractor)(context)
def paramType = extractor.aligner.wholeType
import extractor.treeMaker
// if (!extractor.isTyped)
// ErrorUtils.issueNormalTypeError(tree, "Could not typecheck extractor call: "+ extractor)(context)

// chain a type-testing extractor before the actual extractor call
// it tests the type, checks the outer pointer and casts to the expected type
Expand Down Expand Up @@ -355,36 +380,20 @@ trait MatchTranslation {
object ExtractorCall {
// TODO: check unargs == args
def apply(tree: Tree): ExtractorCall = tree match {
case UnApply(unfun, args) => new ExtractorCallRegular(unfun, args) // extractor
case Apply(fun, args) => new ExtractorCallProd(fun, args) // case class
case UnApply(unfun, args) => new ExtractorCallRegular(alignPatterns(tree), unfun, args) // extractor
case Apply(fun, args) => new ExtractorCallProd(alignPatterns(tree), fun, args) // case class
}
}

abstract class ExtractorCall {
abstract class ExtractorCall(val aligner: PatternAligned) {
import aligner._
def fun: Tree
def args: List[Tree]

val nbSubPats = args.length
val starLength = if (hasStar) 1 else 0
val nonStarLength = args.length - starLength

// everything okay, captain?
def isTyped: Boolean
def isSeq: Boolean

private def hasStar = nbSubPats > 0 && isStar(args.last)
private def isNonEmptySeq = nbSubPats > 0 && isSeq

/** This is special cased so that a single pattern will accept any extractor
* result, even if it's a tuple (SI-6675)
*/
def isSingle = nbSubPats == 1 && !isSeq

// to which type should the previous binder be casted?
def paramType : Type

protected def rawSubPatTypes: List[Type]
protected def resultType: Type
// don't go looking for selectors if we only expect one pattern
def rawSubPatTypes = aligner.extractedTypes
def resultInMonad = if (isBool) UnitTpe else typeOfMemberNamedGet(resultType)
def resultType = fun.tpe.finalResultType

/** Create the TreeMaker that embodies this extractor call
*
Expand All @@ -407,24 +416,14 @@ trait MatchTranslation {
lazy val ignoredSubPatBinders: Set[Symbol] = subPatBinders zip args collect { case (b, PatternBoundToUnderscore()) => b } toSet

// do repeated-parameter expansion to match up with the expected number of arguments (in casu, subpatterns)
private def nonStarSubPatTypes = formalTypes(rawInit :+ repeatedType, nonStarLength)
private def nonStarSubPatTypes = aligner.typedNonStarPatterns map (_.tpe)

def subPatTypes: List[Type] = (
if (rawSubPatTypes.isEmpty || !isSeq) rawSubPatTypes
else if (hasStar) nonStarSubPatTypes :+ sequenceType
else nonStarSubPatTypes
)

private def rawGet = typeOfMemberNamedGetOrSelf(resultType)
private def rawInit = rawSubPatTypes dropRight 1
protected def sequenceType = typeOfLastSelectorOrSelf(rawGet)
protected def elementType = elementTypeOfLastSelectorOrSelf(rawGet)
protected def repeatedType = scalaRepeatedType(elementType)
def subPatTypes: List[Type] = typedPatterns map (_.tpe)

// rawSubPatTypes.last is the Seq, thus there are `rawSubPatTypes.length - 1` non-seq elements in the tuple
protected def firstIndexingBinder = rawSubPatTypes.length - 1
protected def lastIndexingBinder = nbSubPats - 1 - starLength
protected def expectedLength = lastIndexingBinder - firstIndexingBinder + 1
// there are `productArity` non-seq elements in the tuple.
protected def firstIndexingBinder = productArity
protected def expectedLength = elementArity
protected def lastIndexingBinder = totalArity - starArity - 1

private def productElemsToN(binder: Symbol, n: Int): List[Tree] = 1 to n map tupleSel(binder) toList
private def genTake(binder: Symbol, n: Int): List[Tree] = (0 until n).toList map (codegen index seqTree(binder))
Expand All @@ -438,12 +437,12 @@ trait MatchTranslation {
// referenced by `binder`
protected def subPatRefsSeq(binder: Symbol): List[Tree] = {
def lastTrees: List[Tree] = (
if (!hasStar) Nil
if (!aligner.isStar) Nil
else if (expectedLength == 0) seqTree(binder) :: Nil
else genDrop(binder, expectedLength)
)
// this error-condition has already been checked by checkStarPatOK:
// if(isSeq) assert(firstIndexingBinder + nbIndexingIndices + (if(lastIsStar) 1 else 0) == nbSubPats, "(resultInMonad, ts, subPatTypes, subPats)= "+(resultInMonad, ts, subPatTypes, subPats))
// if(isSeq) assert(firstIndexingBinder + nbIndexingIndices + (if(lastIsStar) 1 else 0) == totalArity, "(resultInMonad, ts, subPatTypes, subPats)= "+(resultInMonad, ts, subPatTypes, subPats))

// [1] there are `firstIndexingBinder` non-seq tuple elements preceding the Seq
// [2] then we have to index the binder that represents the sequence for the remaining subpatterns, except for...
Expand All @@ -457,8 +456,10 @@ trait MatchTranslation {

// the trees that select the subpatterns on the extractor's result, referenced by `binder`
// require (nbSubPats > 0 && (!lastIsStar || isSeq))
protected def subPatRefs(binder: Symbol): List[Tree] =
if (isNonEmptySeq) subPatRefsSeq(binder) else productElemsToN(binder, nbSubPats)
protected def subPatRefs(binder: Symbol): List[Tree] = (
if (totalArity > 0 && isSeq) subPatRefsSeq(binder)
else productElemsToN(binder, totalArity)
)

private def compareInts(t1: Tree, t2: Tree) =
gen.mkMethodCall(termMember(ScalaPackage, "math"), TermName("signum"), Nil, (t1 INT_- t2) :: Nil)
Expand All @@ -478,7 +479,7 @@ trait MatchTranslation {
// when the last subpattern is a wildcard-star the expectedLength is but a lower bound
// (otherwise equality is required)
def compareOp: (Tree, Tree) => Tree =
if (hasStar) _ INT_>= _
if (aligner.isStar) _ INT_>= _
else _ INT_== _

// `if (binder != null && $checkExpectedLength [== | >=] 0) then else zero`
Expand All @@ -487,26 +488,14 @@ trait MatchTranslation {

def checkedLength: Option[Int] =
// no need to check unless it's an unapplySeq and the minimal length is non-trivially satisfied
if (!isSeq || expectedLength < starLength) None
if (!isSeq || expectedLength < starArity) None
else Some(expectedLength)
}

// TODO: to be called when there's a def unapplyProd(x: T): U
// U must have N members _1,..., _N -- the _i are type checked, call their type Ti,
// for now only used for case classes -- pretending there's an unapplyProd that's the identity (and don't call it)
class ExtractorCallProd(val fun: Tree, val args: List[Tree]) extends ExtractorCall {
private def constructorTp = fun.tpe

def isTyped = fun.isTyped

// to which type should the previous binder be casted?
def paramType = constructorTp.finalResultType
def resultType = fun.tpe.finalResultType

def isSeq = isVarArgTypes(rawSubPatTypes)

protected def rawSubPatTypes = constructorTp.paramTypes

class ExtractorCallProd(aligner: PatternAligned, val fun: Tree, val args: List[Tree]) extends ExtractorCall(aligner) {
/** Create the TreeMaker that embodies this extractor call
*
* `binder` has been casted to `paramType` if necessary
Expand Down Expand Up @@ -535,20 +524,11 @@ trait MatchTranslation {
if (accessors isDefinedAt (i-1)) REF(binder) DOT accessors(i-1)
else codegen.tupleSel(binder)(i) // this won't type check for case classes, as they do not inherit ProductN
}

override def toString() = s"ExtractorCallProd($fun:${fun.tpe} / ${fun.symbol} / args=$args)"
}

class ExtractorCallRegular(extractorCallIncludingDummy: Tree, val args: List[Tree]) extends ExtractorCall {
class ExtractorCallRegular(aligner: PatternAligned, extractorCallIncludingDummy: Tree, val args: List[Tree]) extends ExtractorCall(aligner) {
val Unapplied(fun) = extractorCallIncludingDummy

def tpe = fun.tpe
def paramType = firstParamType(tpe)
def resultType = tpe.finalResultType
def isTyped = (tpe ne NoType) && fun.isTyped && (resultInMonad ne ErrorType)
def isSeq = fun.symbol.name == nme.unapplySeq
def isBool = resultType =:= BooleanTpe

/** Create the TreeMaker that embodies this extractor call
*
* `binder` has been casted to `paramType` if necessary
Expand All @@ -571,7 +551,7 @@ trait MatchTranslation {
ExtractorTreeMaker(extractorApply, lengthGuard(binder), binder)(
subPatBinders,
subPatRefs(binder),
isBool,
aligner.isBool,
checkedLength,
patBinderOrCasted,
ignoredSubPatBinders
Expand All @@ -583,9 +563,9 @@ trait MatchTranslation {
else super.seqTree(binder)

// the trees that select the subpatterns on the extractor's result, referenced by `binder`
// require (nbSubPats > 0 && (!lastIsStar || isSeq))
// require (totalArity > 0 && (!lastIsStar || isSeq))
override protected def subPatRefs(binder: Symbol): List[Tree] =
if (isSingle) REF(binder) :: Nil // special case for extractors
if (aligner.isSingle) REF(binder) :: Nil // special case for extractors
else super.subPatRefs(binder)

protected def spliceApply(binder: Symbol): Tree = {
Expand All @@ -606,40 +586,7 @@ trait MatchTranslation {
splice transform extractorCallIncludingDummy
}

// what's the extractor's result type in the monad? It is the type of its nullary member `get`.
protected lazy val resultInMonad: Type = if (isBool) UnitTpe else typeOfMemberNamedGet(resultType)

protected lazy val rawSubPatTypes = (
if (isBool) Nil
else if (isSingle) resultInMonad :: Nil // don't go looking for selectors if we only expect one pattern
else typesOfSelectorsOrSelf(resultInMonad)
)

override def toString() = s"ExtractorCallRegular($fun: $tpe / ${fun.symbol})"
}

/** A conservative approximation of which patterns do not discern anything.
* They are discarded during the translation.
*/
object WildcardPattern {
def unapply(pat: Tree): Boolean = pat match {
case Bind(nme.WILDCARD, WildcardPattern()) => true // don't skip when binding an interesting symbol!
case Star(WildcardPattern()) => true
case x: Ident => treeInfo.isVarPattern(x)
case Alternative(ps) => ps forall unapply
case EmptyTree => true
case _ => false
}
}

object PatternBoundToUnderscore {
def unapply(pat: Tree): Boolean = pat match {
case Bind(nme.WILDCARD, _) => true // don't skip when binding an interesting symbol!
case Ident(nme.WILDCARD) => true
case Alternative(ps) => ps forall unapply
case Typed(PatternBoundToUnderscore(), _) => true
case _ => false
}
override def rawSubPatTypes = aligner.extractor.varargsTypes
}
}
}
78 changes: 44 additions & 34 deletions src/compiler/scala/tools/nsc/transform/patmat/MatchTreeMaking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,10 @@ trait MatchTreeMaking extends MatchCodeGen with Debugging {
debug.patmat("TTTM"+((prevBinder, extractorArgTypeTest, testedBinder, expectedTp, nextBinderTp)))

lazy val outerTestNeeded = (
!((expectedTp.prefix eq NoPrefix) || expectedTp.prefix.typeSymbol.isPackageClass)
&& needsOuterTest(expectedTp, testedBinder.info, matchOwner))
(expectedTp.prefix ne NoPrefix)
&& !expectedTp.prefix.typeSymbol.isPackageClass
&& needsOuterTest(expectedTp, testedBinder.info, matchOwner)
)

// the logic to generate the run-time test that follows from the fact that
// a `prevBinder` is expected to have type `expectedTp`
Expand All @@ -406,44 +408,52 @@ trait MatchTreeMaking extends MatchCodeGen with Debugging {
def renderCondition(cs: TypeTestCondStrategy): cs.Result = {
import cs._

def default =
// do type test first to ensure we won't select outer on null
if (outerTestNeeded) and(typeTest(testedBinder, expectedTp), outerTest(testedBinder, expectedTp))
else typeTest(testedBinder, expectedTp)

// propagate expected type
def expTp(t: Tree): t.type = t setType expectedTp

def testedWide = testedBinder.info.widen
def expectedWide = expectedTp.widen
def isAnyRef = testedWide <:< AnyRefTpe
def isAsExpected = testedWide <:< expectedTp
def isExpectedPrimitiveType = isAsExpected && isPrimitiveValueType(expectedTp)
def isExpectedReferenceType = isAsExpected && (expectedTp <:< AnyRefTpe)
def mkNullTest = nonNullTest(testedBinder)
def mkOuterTest = outerTest(testedBinder, expectedTp)
def mkTypeTest = typeTest(testedBinder, expectedWide)

def mkEqualsTest(lhs: Tree): cs.Result = equalsTest(lhs, testedBinder)
def mkEqTest(lhs: Tree): cs.Result = eqTest(lhs, testedBinder)
def addOuterTest(res: cs.Result): cs.Result = if (outerTestNeeded) and(res, mkOuterTest) else res

// If we conform to expected primitive type:
// it cannot be null and cannot have an outer pointer. No further checking.
// If we conform to expected reference type:
// have to test outer and non-null
// If we do not conform to expected type:
// have to test type and outer (non-null is implied by successful type test)
def mkDefault = (
if (isExpectedPrimitiveType) tru
else addOuterTest(
if (isExpectedReferenceType) mkNullTest
else mkTypeTest
)
)

// true when called to type-test the argument to an extractor
// don't do any fancy equality checking, just test the type
if (extractorArgTypeTest) default
// TODO: verify that we don't need to special-case Array
// I think it's okay:
// - the isInstanceOf test includes a test for the element type
// - Scala's arrays are invariant (so we don't drop type tests unsoundly)
if (extractorArgTypeTest) mkDefault
else expectedTp match {
// TODO: [SPEC] the spec requires `eq` instead of `==` for singleton types
// this implies sym.isStable
case SingleType(_, sym) => and(equalsTest(gen.mkAttributedQualifier(expectedTp), testedBinder), typeTest(testedBinder, expectedTp.widen))
// must use == to support e.g. List() == Nil
case ThisType(sym) if sym.isModule => and(equalsTest(CODE.REF(sym), testedBinder), typeTest(testedBinder, expectedTp.widen))
case ConstantType(Constant(null)) if testedBinder.info.widen <:< AnyRefTpe
=> eqTest(expTp(CODE.NULL), testedBinder)
case ConstantType(const) => equalsTest(expTp(Literal(const)), testedBinder)
case ThisType(sym) => eqTest(expTp(This(sym)), testedBinder)

// TODO: verify that we don't need to special-case Array
// I think it's okay:
// - the isInstanceOf test includes a test for the element type
// - Scala's arrays are invariant (so we don't drop type tests unsoundly)
case _ if testedBinder.info.widen <:< expectedTp =>
// if the expected type is a primitive value type, it cannot be null and it cannot have an outer pointer
// since the types conform, no further checking is required
if (isPrimitiveValueType(expectedTp)) tru
// have to test outer and non-null only when it's a reference type
else if (expectedTp <:< AnyRefTpe) {
// do non-null check first to ensure we won't select outer on null
if (outerTestNeeded) and(nonNullTest(testedBinder), outerTest(testedBinder, expectedTp))
else nonNullTest(testedBinder)
} else default

case _ => default
// TODO: [SPEC] the spec requires `eq` instead of `==` for singleton types - this implies sym.isStable
case SingleType(_, sym) => and(mkEqualsTest(gen.mkAttributedQualifier(expectedTp)), mkTypeTest)
case ThisType(sym) if sym.isModule => and(mkEqualsTest(CODE.REF(sym)), mkTypeTest) // must use == to support e.g. List() == Nil
case ConstantType(Constant(null)) if isAnyRef => mkEqTest(expTp(CODE.NULL))
case ConstantType(const) => mkEqualsTest(expTp(Literal(const)))
case ThisType(sym) => mkEqTest(expTp(This(sym)))
case _ => mkDefault
}
}

Expand Down
Loading

0 comments on commit 11bfa25

Please sign in to comment.