Skip to content

Commit

Permalink
Extractor pattern no longer matches null
Browse files Browse the repository at this point in the history
Until now, the spec did not say anything about `null` for extractor patterns, which meant that:
  - The pattern matcher would happily pass `null` into your extractor.
  - One could write a `null`-matching extractor `MyNull`.
  - But all extractor authors had to consider `null` as a possible argument value.

No more! The pattern matcher inserts a non-`null` check before invoking an extractor,
so that you don't have to.

This is a general fix for scala/bug#2241, scala/bug#8787. See scala/bug#4364.
  • Loading branch information
eed3si9n authored and adriaanm committed Jun 1, 2018
1 parent f16eacc commit 1f0c8a7
Show file tree
Hide file tree
Showing 15 changed files with 202 additions and 78 deletions.
3 changes: 3 additions & 0 deletions spec/08-pattern-matching.md
Expand Up @@ -192,6 +192,9 @@ a case class, the stable identifier $x$ denotes an object which has a
member method named `unapply` or `unapplySeq` that matches
the pattern.

An extractor pattern cannot match the value `null`. The implementation
ensures that the `unapply`/`unapplySeq` method is not applied to `null`.

An `unapply` method in an object $x$ _matches_ the pattern
$x(p_1 , \ldots , p_n)$ if it takes exactly one argument and one of
the following applies:
Expand Down
Expand Up @@ -344,6 +344,7 @@ trait MatchApproximation extends TreeAndTypeAnalysis with ScalaLogic with MatchT
case AlternativesTreeMaker(_, altss, _) => \/(altss map (alts => /\(alts map this)))
case ProductExtractorTreeMaker(testedBinder, None) => uniqueNonNullProp(binderToUniqueTree(testedBinder))
case SubstOnlyTreeMaker(_, _) => True
case NonNullTestTreeMaker(prevBinder, _, _) => uniqueNonNullProp(binderToUniqueTree(prevBinder))
case GuardTreeMaker(guard) =>
guard.tpe match {
case ConstantTrue => True
Expand Down
Expand Up @@ -111,20 +111,28 @@ trait MatchTranslation {
val (makers, unappBinder) = {
val paramType = extractor.expectedExtractedType
// Statically conforms to paramType
if (tpe <:< paramType) (treeMakers(binder, false, pos), binder)
if (tpe <:< paramType) {
// enforce all extractor patterns to be non-null
val nonNullTest = NonNullTestTreeMaker(binder, paramType, pos)
val unappBinder = nonNullTest.nextBinder
(nonNullTest :: treeMakers(unappBinder, pos), unappBinder)
}
else {
// chain a type-testing extractor before the actual extractor call
// it tests the type, checks the outer pointer and casts to the expected type
// TODO: the outer check is mandated by the spec for case classes, but we do it for user-defined unapplies as well [SPEC]
// (the prefix of the argument passed to the unapply must equal the prefix of the type of the binder)
val typeTest = TypeTestTreeMaker(binder, binder, paramType, paramType)(pos, extractorArgTypeTest = true)
val binderKnownNonNull = typeTest impliesBinderNonNull binder

// check whether typetest implies binder is not null,
// even though the eventual null check will be on typeTest.nextBinder
// it'll be equal to binder casted to paramType anyway (and the type test is on binder)
val unappBinder = typeTest.nextBinder
(typeTest :: treeMakers(unappBinder, binderKnownNonNull, pos), unappBinder)
// skip null test if it's implied
if (binderKnownNonNull) {
val unappBinder = typeTest.nextBinder
(typeTest :: treeMakers(unappBinder, pos), unappBinder)
} else {
val nonNullTest = NonNullTestTreeMaker(typeTest.nextBinder, paramType, pos)
val unappBinder = nonNullTest.nextBinder
(typeTest :: nonNullTest :: treeMakers(unappBinder, pos), unappBinder)
}
}
}

Expand Down Expand Up @@ -380,11 +388,8 @@ trait MatchTranslation {

abstract class ExtractorCall(fun: Tree, args: List[Tree]) extends ExtractorAlignment(fun, args)(context) {
/** Create the TreeMaker that embodies this extractor call
*
* `binderKnownNonNull` indicates whether the cast implies `binder` cannot be null
* when `binderKnownNonNull` is `true`, `ProductExtractorTreeMaker` does not do a (redundant) null check on binder
*/
def treeMakers(binder: Symbol, binderKnownNonNull: Boolean, pos: Position): List[TreeMaker]
def treeMakers(binder: Symbol, pos: Position): List[TreeMaker]

// `subPatBinders` are the variables bound by this pattern in the following patterns
// subPatBinders are replaced by references to the relevant part of the extractor's result (tuple component, seq element, the result as-is)
Expand Down Expand Up @@ -480,10 +485,8 @@ trait MatchTranslation {
/** Create the TreeMaker that embodies this extractor call
*
* `binder` has been casted to `paramType` if necessary
* `binderKnownNonNull` indicates whether the cast implies `binder` cannot be null
* when `binderKnownNonNull` is `true`, `ProductExtractorTreeMaker` does not do a (redundant) null check on binder
*/
def treeMakers(binder: Symbol, binderKnownNonNull: Boolean, pos: Position): List[TreeMaker] = {
def treeMakers(binder: Symbol, pos: Position): List[TreeMaker] = {
val paramAccessors = expectedExtractedType.typeSymbol.constrParamAccessors
val numParams = paramAccessors.length
def paramAccessorAt(subPatIndex: Int) = paramAccessors(math.min(subPatIndex, numParams - 1))
Expand All @@ -504,7 +507,7 @@ trait MatchTranslation {
)

// checks binder ne null before chaining to the next extractor
ProductExtractorTreeMaker(binder, lengthGuard(binder))(subPatBinders, subPatRefs(binder), mutableBinders, binderKnownNonNull, ignoredSubPatBinders) :: Nil
ProductExtractorTreeMaker(binder, lengthGuard(binder))(subPatBinders, subPatRefs(binder), mutableBinders, ignoredSubPatBinders) :: Nil
}

// reference the (i-1)th case accessor if it exists, otherwise the (i-1)th tuple component
Expand All @@ -531,14 +534,13 @@ trait MatchTranslation {
/** Create the TreeMaker that embodies this extractor call
*
* `binder` has been casted to `paramType` if necessary
* `binderKnownNonNull` is not used in this subclass
*
* TODO: implement review feedback by @retronym:
* Passing the pair of values around suggests:
* case class Binder(sym: Symbol, knownNotNull: Boolean).
* Perhaps it hasn't reached critical mass, but it would already clean things up a touch.
*/
def treeMakers(patBinderOrCasted: Symbol, binderKnownNonNull: Boolean, pos: Position): List[TreeMaker] = {
def treeMakers(patBinderOrCasted: Symbol, pos: Position): List[TreeMaker] = {
// the extractor call (applied to the binder bound by the flatMap corresponding
// to the previous (i.e., enclosing/outer) pattern)
val (extractorApply, needsSubst) = spliceApply(pos, patBinderOrCasted)
Expand Down
Expand Up @@ -193,6 +193,32 @@ trait MatchTreeMaking extends MatchCodeGen with Debugging {
}
}

/**
 * A pass-through TreeMaker that guards the remainder of the match chain with a
 * `prevBinder ne null` test. It is placed ahead of an extractor call.
 *
 * When `expectedTp` is a primitive value type no test is emitted and the
 * continuation is used as-is; otherwise the continuation is wrapped via
 * `casegen.ifThenElseZero(nullCheck, next)`.
 */
case class NonNullTestTreeMaker(
    prevBinder: Symbol,
    expectedTp: Type,
    override val pos: Position
) extends FunTreeMaker {
  import CODE._

  // No new binder is introduced: the incoming one is handed on unchanged.
  override lazy val nextBinder = prevBinder.asTerm
  val nextBinderTp = nextBinder.info.widen

  val nullCheck = REF(prevBinder) OBJ_NE NULL
  lazy val localSubstitution = Substitution(Nil, Nil)

  def isExpectedPrimitiveType = isPrimitiveValueType(expectedTp)

  def chainBefore(next: Tree)(casegen: Casegen): Tree = {
    // Only reference-typed scrutinees get the guard.
    val guarded: Tree =
      if (!isExpectedPrimitiveType) casegen.ifThenElseZero(nullCheck, next)
      else next
    atPos(pos)(guarded)
  }

  override def toString = s"NN(${prevBinder.name})"
}

/**
* Make a TreeMaker that will result in an extractor call specified by `extractor`
* the next TreeMaker (here, we don't know which it'll be) is chained after this one by flatMap'ing
Expand Down Expand Up @@ -268,7 +294,6 @@ trait MatchTreeMaking extends MatchCodeGen with Debugging {
val subPatBinders: List[Symbol],
val subPatRefs: List[Tree],
val mutableBinders: List[Symbol],
binderKnownNonNull: Boolean,
val ignoredSubPatBinders: Set[Symbol]
) extends FunTreeMaker with PreserveSubPatBinders {

Expand All @@ -279,13 +304,7 @@ trait MatchTreeMaking extends MatchCodeGen with Debugging {
def extraStoredBinders: Set[Symbol] = mutableBinders.toSet

def chainBefore(next: Tree)(casegen: Casegen): Tree = {
val nullCheck = REF(prevBinder) OBJ_NE NULL
val cond =
if (binderKnownNonNull) extraCond
else (extraCond map (nullCheck AND _)
orElse Some(nullCheck))

cond match {
extraCond match {
case Some(cond) =>
casegen.ifThenElseZero(cond, bindSubPats(substitution(next)))
case _ =>
Expand Down
6 changes: 1 addition & 5 deletions src/library/scala/Array.scala
Expand Up @@ -498,11 +498,7 @@ object Array {
* @return sequence wrapped in a [[scala.Some]], if `x` is an Array, otherwise `None`
*/
def unapplySeq[T](x: Array[T]): Option[IndexedSeq[T]] =
if (x == null) None else Some(ArraySeq.unsafeWrapArray[T](x))
// !!! the null check should not be necessary, but without it 2241 fails. Seems to be a bug
// in pattern matcher. @PP: I noted in #4364 I think the behavior is correct.
// Is ArraySeq safe here? In 2.12 we used to call .toIndexedSeq which copied the array
// instead of wrapping it in a ArraySeq but it appears unnecessary.
Some(ArraySeq.unsafeWrapArray[T](x))
}

/** Arrays are mutable, indexed collections of values. `Array[T]` is Scala's representation
Expand Down
8 changes: 3 additions & 5 deletions src/library/scala/util/matching/Regex.scala
Expand Up @@ -279,10 +279,8 @@ class Regex private[matching](val pattern: Pattern, groupNames: String*) extends
* @param s The string to match
* @return The matches
*/
def unapplySeq(s: CharSequence): Option[List[String]] = s match {
case null => None
case _ =>
val m = pattern matcher s
def unapplySeq(s: CharSequence): Option[List[String]] = {
val m = pattern matcher s
if (runMatcher(m)) Some(List.tabulate(m.groupCount) { i => m.group(i + 1) })
else None
}
Expand Down Expand Up @@ -336,7 +334,7 @@ class Regex private[matching](val pattern: Pattern, groupNames: String*) extends
* and the result of that match is used.
*/
def unapplySeq(m: Match): Option[List[String]] =
if (m == null || m.matched == null) None
if (m.matched == null) None
else if (m.matcher.pattern == this.pattern) Regex.extractGroupsFromMatch(m)
else unapplySeq(m.matched)

Expand Down
3 changes: 3 additions & 0 deletions test/files/run/name-based-patmat.check
@@ -1,3 +1,6 @@
name-based-patmat.scala:73: warning: unreachable code
case Foo(5, 10) => 4 // should warn unreachable
^
`catdog only` has 11 chars
`catdog only, no product` has 23 chars
catdog
Expand Down
14 changes: 7 additions & 7 deletions test/files/run/patmatnew.check
@@ -1,21 +1,21 @@
patmatnew.scala:351: warning: a pure expression does nothing in statement position
patmatnew.scala:352: warning: a pure expression does nothing in statement position
case 1 => "OK"
^
patmatnew.scala:352: warning: a pure expression does nothing in statement position
patmatnew.scala:353: warning: a pure expression does nothing in statement position
case 2 => assert(false); "KO"
^
patmatnew.scala:352: warning: multiline expressions might require enclosing parentheses; a value can be silently discarded when Unit is expected
patmatnew.scala:353: warning: multiline expressions might require enclosing parentheses; a value can be silently discarded when Unit is expected
case 2 => assert(false); "KO"
^
patmatnew.scala:353: warning: a pure expression does nothing in statement position
patmatnew.scala:354: warning: a pure expression does nothing in statement position
case 3 => assert(false); "KO"
^
patmatnew.scala:353: warning: multiline expressions might require enclosing parentheses; a value can be silently discarded when Unit is expected
patmatnew.scala:354: warning: multiline expressions might require enclosing parentheses; a value can be silently discarded when Unit is expected
case 3 => assert(false); "KO"
^
patmatnew.scala:670: warning: This catches all Throwables. If this is really intended, use `case e : Throwable` to clear this warning.
patmatnew.scala:671: warning: This catches all Throwables. If this is really intended, use `case e : Throwable` to clear this warning.
case e => {
^
patmatnew.scala:489: warning: unreachable code
patmatnew.scala:490: warning: unreachable code
case _ if false =>
^
73 changes: 73 additions & 0 deletions test/files/run/patmatnew.scala
Expand Up @@ -38,6 +38,7 @@ object Test {
Ticket346.run()
Ticket37.run()
Ticket44.run()
NullMatch.run()
}

def assertEquals(a: Any, b: Any): Unit = { assert(a == b) }
Expand Down Expand Up @@ -762,4 +763,76 @@ object Test {

} // end Ticket346

// scala/bug#4364
// Checks that matching `null` against an extractor pattern falls through to the
// next case. Several of the extractors below would throw if `null` were ever
// passed to them, so each `assert` in `run` fails unless the match skips them.
object NullMatch {
// Sequence extractor that raises an error if it is ever handed `null`.
object XArray {
def unapplySeq[A](x: Array[A]): Option[IndexedSeq[A]] =
if (x eq null) sys.error("Unexpected null!")
else Some(x.toIndexedSeq)
}

// Boolean extractor that raises an error if it is ever handed `null`.
object YArray {
def unapply(xs: Array[Int]): Boolean =
if (xs eq null) sys.error("Unexpected null!")
else true
}

// Extractor with no null guard of its own: `x.toString` would throw a
// NullPointerException if `x` were `null`.
object Animal {
def unapply(x: AnyRef): Option[AnyRef] =
if (x.toString == "Animal") Some(x)
else None
}

// Library `Array` extractor with a sequence wildcard; true when the pattern does not match.
def nullMatch[A](xs: Array[A]): Boolean = xs match {
case Array(xs @_*) => false
case _ => true
}

// Same shape as `nullMatch`, but using the user-defined `XArray.unapplySeq`.
def nullMatch2[A](xs: Array[A]): Boolean = xs match {
case XArray(xs @_*) => false
case _ => true
}

// As `nullMatch2`, with an always-true guard attached to the extractor case.
def nullMatch3[A](xs: Array[A]): Boolean = xs match {
case XArray(xs @_*) if 1 == 1 => false
case _ => true
}

// Boolean extractor (`YArray.unapply`) with an empty argument list.
def nullMatch4(xs: Array[Int]): Boolean = xs match {
case YArray() => false
case _ => true
}

// Option-returning extractor (`Animal.unapply`) on an `AnyRef` scrutinee.
def nullMatch5(x: AnyRef): Boolean = x match {
case Animal(x) => false
case _ => true
}

// scala/bug#8787: a `null` String matched against a Regex pattern must take the default case.
def t8787nullMatch() = {
val r = """\d+""".r
val s: String = null
val x = s match { case r() => 1 ; case _ => 2 }
2 == x
}

// scala/bug#8787: a `null` appended to the list of `Match` results must take the
// default case, so only the three real matches contribute pairs.
def t8787nullMatcher() = {
val r = """(\d+):(\d+)""".r
val s = "1:2 3:4 5:6"
val z = ((r findAllMatchIn s).toList :+ null) flatMap {
case r(x, y) => Some((x.toInt, y.toInt))
case _ => None
}
List((1,2),(3,4),(5,6)) == z
}

// Runs every scenario with `null` as the scrutinee (or as a list element).
def run() {
assert(nullMatch(null))
assert(nullMatch2(null))
assert(nullMatch3(null))
assert(nullMatch4(null))
assert(nullMatch5(null))
assert(t8787nullMatch())
assert(t8787nullMatcher())
}
}
}

0 comments on commit 1f0c8a7

Please sign in to comment.