Skip to content

Commit

Permalink
SI-5189 fixed: safe type infer for constr pattern
Browse files Browse the repository at this point in the history
several fixes to the standard library due to
 - the safer type checker this fix gives us (thus, some casts had to be inserted)
 - SI-5548
 - type inference gets a bit more complicated, it needs help (chainl1 in combinator.Parsers)

To deal with the type slack between actual (run-time) types and statically known
types, for each abstract type T, reflect its variance as a skolem that is
upper-bounded by T (covariant position), or lower-bounded by T (contravariant).

Consider the following example:

 class AbsWrapperCov[+A]
 case class Wrapper[B](x: Wrapped[B]) extends AbsWrapperCov[B]

 def unwrap[T](x: AbsWrapperCov[T]): Wrapped[T] = x match {
   case Wrapper(wrapped) =>
     // Wrapper's type parameter must not be assumed to be equal to T,
     // it's *upper-bounded* by it
     wrapped // : Wrapped[_ <: T]
 }

this  method should  type check  if and  only if  Wrapped is  covariant in  its type
parameter

before inferring Wrapper's type parameter B from x's type AbsWrapperCov[T], we must
take into account that x's actual type is:
AbsWrapperCov[Tactual] forSome {type Tactual <: T}
since AbsWrapperCov is covariant in A -- in other words, we must not assume we know
T exactly, all we know is its upper bound

since method application is the only way to generate this slack between run-time and
compile-time types (TODO: right!?), we can simply replace skolems that represent
method type parameters as seen from the method's body by other skolems that are
(upper/lower)-bounded by that type-parameter skolem (depending on the variance
position of the skolem in the statically assumed type of the scrutinee, pt)

this type slack is introduced by adaptConstrPattern: before it calls
inferConstructorInstance, it creates a new context that holds the new existential
skolems

the context created by adaptConstrPattern must not be a CaseDef, since that
confuses instantiateTypeVar and the whole pushTypeBounds/restoreTypeBounds dance
(CaseDef contexts remember the bounds of the type params that we clobbered
during GADT typing)

typedCase deskolemizes the existential skolems back to the method skolems,
since they don't serve any further purpose (except confusing the old pattern
matcher)

typedCase is now better at finding that context (using nextEnclosing)
  • Loading branch information
adriaanm committed Mar 9, 2012
1 parent 0cffdf3 commit 29bcade
Show file tree
Hide file tree
Showing 12 changed files with 180 additions and 31 deletions.
6 changes: 3 additions & 3 deletions src/compiler/scala/reflect/internal/Symbols.scala
Expand Up @@ -269,9 +269,9 @@ trait Symbols extends api.Symbols { self: SymbolTable =>
/** Create a new existential type skolem with this symbol its owner,
* based on the given symbol and origin.
*/
def newExistentialSkolem(basis: Symbol, origin: AnyRef): TypeSkolem = {
val skolem = newTypeSkolemSymbol(basis.name.toTypeName, origin, basis.pos, (basis.flags | EXISTENTIAL) & ~PARAM)
skolem setInfo (basis.info cloneInfo skolem)
def newExistentialSkolem(basis: Symbol, origin: AnyRef, name: TypeName = null, info: Type = null): TypeSkolem = {
val skolem = newTypeSkolemSymbol(if (name eq null) basis.name.toTypeName else name, origin, basis.pos, (basis.flags | EXISTENTIAL) & ~PARAM)
skolem setInfo (if (info eq null) basis.info cloneInfo skolem else info)
}

final def newExistential(name: TypeName, pos: Position = NoPosition, newFlags: Long = 0L): Symbol =
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/scala/tools/nsc/io/Pickler.scala
Expand Up @@ -165,7 +165,7 @@ object Pickler {
def pkl[T: Pickler] = implicitly[Pickler[T]]

/** A class represenenting `~`-pairs */
case class ~[S, T](fst: S, snd: T)
case class ~[+S, +T](fst: S, snd: T)

/** A wrapper class to be able to use `~` s an infix method */
class TildeDecorator[S](x: S) {
Expand Down
10 changes: 9 additions & 1 deletion src/compiler/scala/tools/nsc/typechecker/Infer.scala
Expand Up @@ -1090,7 +1090,15 @@ trait Infer {
inferFor(pt.instantiateTypeParams(ptparams, ptparams map (x => WildcardType))) flatMap { targs =>
val ctorTpInst = tree.tpe.instantiateTypeParams(undetparams, targs)
val resTpInst = skipImplicit(ctorTpInst.finalResultType)
val ptvars = ptparams map freshVar
val ptvars =
ptparams map {
// since instantiateTypeVar wants to modify the skolem that corresponds to the method's type parameter,
// and it uses the TypeVar's origin to locate it, deskolemize the existential skolem to the method tparam skolem
// (the existential skolem was created by adaptConstrPattern to introduce the type slack necessary to soundly deal with variant type parameters)
case skolem if skolem.isExistentialSkolem => freshVar(skolem.deSkolemize.asInstanceOf[TypeSymbol])
case p => freshVar(p)
}

val ptV = pt.instantiateTypeParams(ptparams, ptvars)

if (isPopulated(resTpInst, ptV)) {
Expand Down
86 changes: 79 additions & 7 deletions src/compiler/scala/tools/nsc/typechecker/Typers.scala
Expand Up @@ -852,6 +852,33 @@ trait Typers extends Modes with Adaptations with PatMatVirtualiser {
}
}

/**
* To deal with the type slack between actual (run-time) types and statically known types, for each abstract type T,
* reflect its variance as a skolem that is upper-bounded by T (covariant position), or lower-bounded by T (contravariant).
*
* Consider the following example:
*
* class AbsWrapperCov[+A]
* case class Wrapper[B](x: Wrapped[B]) extends AbsWrapperCov[B]
*
* def unwrap[T](x: AbsWrapperCov[T]): Wrapped[T] = x match {
* case Wrapper(wrapped) => // Wrapper's type parameter must not be assumed to be equal to T, it's *upper-bounded* by it
* wrapped // : Wrapped[_ <: T]
* }
*
* this method should type check if and only if Wrapped is covariant in its type parameter
*
* when inferring Wrapper's type parameter B from x's type AbsWrapperCov[T],
* we must take into account that x's actual type is AbsWrapperCov[Tactual] forSome {type Tactual <: T}
* as AbsWrapperCov is covariant in A -- in other words, we must not assume we know T exactly, all we know is its upper bound
*
* since method application is the only way to generate this slack between run-time and compile-time types (TODO: right!?),
* we can simply replace skolems that represent method type parameters as seen from the method's body
* by other skolems that are (upper/lower)-bounded by that type-parameter skolem
* (depending on the variance position of the skolem in the statically assumed type of the scrutinee, pt)
*
* see test/files/../t5189*.scala
*/
def adaptConstrPattern(): Tree = { // (5)
val extractor = tree.symbol.filter(sym => reallyExists(unapplyMember(sym.tpe)))
if (extractor != NoSymbol) {
Expand All @@ -865,7 +892,32 @@ trait Typers extends Modes with Adaptations with PatMatVirtualiser {
val tree1 = TypeTree(clazz.primaryConstructor.tpe.asSeenFrom(prefix, clazz.owner))
.setOriginal(tree)

inferConstructorInstance(tree1, clazz.typeParams, pt)
val skolems = new mutable.ListBuffer[TypeSymbol]
object variantToSkolem extends VariantTypeMap {
def apply(tp: Type) = mapOver(tp) match {
case TypeRef(NoPrefix, tpSym, Nil) if variance != 0 && tpSym.isTypeParameterOrSkolem && tpSym.owner.isTerm =>
val bounds = if (variance == 1) TypeBounds.upper(tpSym.tpe) else TypeBounds.lower(tpSym.tpe)
val skolem = context.owner.newExistentialSkolem(tpSym, tpSym, unit.freshTypeName("?"+tpSym.name), bounds)
// println("mapping "+ tpSym +" to "+ skolem + " : "+ bounds +" -- pt= "+ pt)
skolems += skolem
skolem.tpe
case tp1 => tp1
}
}

// have to open up the existential and put the skolems in scope
// can't simply package up pt in an ExistentialType, because that takes us back to square one (List[_ <: T] == List[T] due to covariance)
val ptSafe = variantToSkolem(pt) // TODO: pt.skolemizeExistential(context.owner, tree) ?
val freeVars = skolems.toList

// use "tree" for the context, not context.tree: don't make another CaseDef context,
// as instantiateTypeVar's bounds would end up there
val ctorContext = context.makeNewScope(tree, context.owner)
freeVars foreach ctorContext.scope.enter
newTyper(ctorContext).infer.inferConstructorInstance(tree1, clazz.typeParams, ptSafe)

// tree1's type-slack skolems will be deskolemized (to the method type parameter skolems)
// once the containing CaseDef has been type checked (see typedCase)
tree1
} else {
tree
Expand Down Expand Up @@ -1986,15 +2038,35 @@ trait Typers extends Modes with Adaptations with PatMatVirtualiser {
val guard1: Tree = if (cdef.guard == EmptyTree) EmptyTree
else typed(cdef.guard, BooleanClass.tpe)
var body1: Tree = typed(cdef.body, pt)
if (!context.savedTypeBounds.isEmpty) {
body1.tpe = context.restoreTypeBounds(body1.tpe)
if (isFullyDefined(pt) && !(body1.tpe <:< pt)) {
// @M no need for pt.normalize here, is done in erasure

val contextWithTypeBounds = context.nextEnclosing(_.tree.isInstanceOf[CaseDef])
if (contextWithTypeBounds.savedTypeBounds nonEmpty) {
body1.tpe = contextWithTypeBounds restoreTypeBounds body1.tpe

// insert a cast if something typechecked under the GADT constraints,
// but not in real life (i.e., now that's we've reset the method's type skolems'
// infos back to their pre-GADT-constraint state)
if (isFullyDefined(pt) && !(body1.tpe <:< pt))
body1 = typedPos(body1.pos)(gen.mkCast(body1, pt))
}

}

// body1 = checkNoEscaping.locals(context.scope, pt, body1)
treeCopy.CaseDef(cdef, pat1, guard1, body1) setType body1.tpe
val treeWithSkolems = treeCopy.CaseDef(cdef, pat1, guard1, body1) setType body1.tpe

// undo adaptConstrPattern's evil deeds, as they confuse the old pattern matcher
// TODO: Paul, can we do the deskolemization lazily in the old pattern matcher
object deskolemizeOnce extends TypeMap {
def apply(tp: Type): Type = mapOver(tp) match {
case TypeRef(pre, sym, args) if sym.isExistentialSkolem && sym.deSkolemize.isSkolem && sym.deSkolemize.owner.isTerm =>
typeRef(NoPrefix, sym.deSkolemize, args)
case tp1 => tp1
}
}

new TypeMapTreeSubstituter(deskolemizeOnce).traverse(treeWithSkolems)

treeWithSkolems // now without skolems, actually
}

def typedCases(cases: List[CaseDef], pattp: Type, pt: Type): List[CaseDef] =
Expand Down
12 changes: 6 additions & 6 deletions src/library/scala/collection/JavaConversions.scala
Expand Up @@ -69,7 +69,7 @@ object JavaConversions {
* @return A Java Iterator view of the argument.
*/
implicit def asJavaIterator[A](it: Iterator[A]): ju.Iterator[A] = it match {
case JIteratorWrapper(wrapped) => wrapped
case JIteratorWrapper(wrapped) => wrapped.asInstanceOf[ju.Iterator[A]]
case _ => IteratorWrapper(it)
}

Expand All @@ -87,7 +87,7 @@ object JavaConversions {
* @return A Java Enumeration view of the argument.
*/
implicit def asJavaEnumeration[A](it: Iterator[A]): ju.Enumeration[A] = it match {
case JEnumerationWrapper(wrapped) => wrapped
case JEnumerationWrapper(wrapped) => wrapped.asInstanceOf[ju.Enumeration[A]]
case _ => IteratorWrapper(it)
}

Expand All @@ -105,7 +105,7 @@ object JavaConversions {
* @return A Java Iterable view of the argument.
*/
implicit def asJavaIterable[A](i: Iterable[A]): jl.Iterable[A] = i match {
case JIterableWrapper(wrapped) => wrapped
case JIterableWrapper(wrapped) => wrapped.asInstanceOf[jl.Iterable[A]]
case _ => IterableWrapper(i)
}

Expand All @@ -121,7 +121,7 @@ object JavaConversions {
* @return A Java Collection view of the argument.
*/
implicit def asJavaCollection[A](it: Iterable[A]): ju.Collection[A] = it match {
case JCollectionWrapper(wrapped) => wrapped
case JCollectionWrapper(wrapped) => wrapped.asInstanceOf[ju.Collection[A]]
case _ => new IterableWrapper(it)
}

Expand Down Expand Up @@ -179,7 +179,7 @@ object JavaConversions {
* @return A Java List view of the argument.
*/
implicit def seqAsJavaList[A](seq: Seq[A]): ju.List[A] = seq match {
case JListWrapper(wrapped) => wrapped
case JListWrapper(wrapped) => wrapped.asInstanceOf[ju.List[A]]
case _ => new SeqWrapper(seq)
}

Expand Down Expand Up @@ -286,7 +286,7 @@ object JavaConversions {
*/
implicit def mapAsJavaMap[A, B](m: Map[A, B]): ju.Map[A, B] = m match {
//case JConcurrentMapWrapper(wrapped) => wrapped
case JMapWrapper(wrapped) => wrapped
case JMapWrapper(wrapped) => wrapped.asInstanceOf[ju.Map[A, B]]
case _ => new MapWrapper(m)
}

Expand Down
8 changes: 4 additions & 4 deletions src/library/scala/collection/immutable/IntMap.scala
Expand Up @@ -353,19 +353,19 @@ extends AbstractMap[Int, T]
def unionWith[S >: T](that : IntMap[S], f : (Int, S, S) => S) : IntMap[S] = (this, that) match{
case (IntMap.Bin(p1, m1, l1, r1), that@(IntMap.Bin(p2, m2, l2, r2))) =>
if (shorter(m1, m2)) {
if (!hasMatch(p2, p1, m1)) join(p1, this, p2, that);
if (!hasMatch(p2, p1, m1)) join[S](p1, this, p2, that); // TODO: remove [S] when SI-5548 is fixed
else if (zero(p2, m1)) IntMap.Bin(p1, m1, l1.unionWith(that, f), r1);
else IntMap.Bin(p1, m1, l1, r1.unionWith(that, f));
} else if (shorter(m2, m1)){
if (!hasMatch(p1, p2, m2)) join(p1, this, p2, that);
if (!hasMatch(p1, p2, m2)) join[S](p1, this, p2, that); // TODO: remove [S] when SI-5548 is fixed
else if (zero(p1, m2)) IntMap.Bin(p2, m2, this.unionWith(l2, f), r2);
else IntMap.Bin(p2, m2, l2, this.unionWith(r2, f));
}
else {
if (p1 == p2) IntMap.Bin(p1, m1, l1.unionWith(l2,f), r1.unionWith(r2, f));
else join(p1, this, p2, that);
else join[S](p1, this, p2, that); // TODO: remove [S] when SI-5548 is fixed
}
case (IntMap.Tip(key, value), x) => x.updateWith(key, value, (x, y) => f(key, y, x));
case (IntMap.Tip(key, value), x) => x.updateWith[S](key, value, (x, y) => f(key, y, x));
case (x, IntMap.Tip(key, value)) => x.updateWith[S](key, value, (x, y) => f(key, x, y));
case (IntMap.Nil, x) => x;
case (x, IntMap.Nil) => x;
Expand Down
8 changes: 4 additions & 4 deletions src/library/scala/collection/immutable/LongMap.scala
Expand Up @@ -349,19 +349,19 @@ extends AbstractMap[Long, T]
def unionWith[S >: T](that : LongMap[S], f : (Long, S, S) => S) : LongMap[S] = (this, that) match{
case (LongMap.Bin(p1, m1, l1, r1), that@(LongMap.Bin(p2, m2, l2, r2))) =>
if (shorter(m1, m2)) {
if (!hasMatch(p2, p1, m1)) join(p1, this, p2, that);
if (!hasMatch(p2, p1, m1)) join[S](p1, this, p2, that); // TODO: remove [S] when SI-5548 is fixed
else if (zero(p2, m1)) LongMap.Bin(p1, m1, l1.unionWith(that, f), r1);
else LongMap.Bin(p1, m1, l1, r1.unionWith(that, f));
} else if (shorter(m2, m1)){
if (!hasMatch(p1, p2, m2)) join(p1, this, p2, that);
if (!hasMatch(p1, p2, m2)) join[S](p1, this, p2, that); // TODO: remove [S] when SI-5548 is fixed
else if (zero(p1, m2)) LongMap.Bin(p2, m2, this.unionWith(l2, f), r2);
else LongMap.Bin(p2, m2, l2, this.unionWith(r2, f));
}
else {
if (p1 == p2) LongMap.Bin(p1, m1, l1.unionWith(l2,f), r1.unionWith(r2, f));
else join(p1, this, p2, that);
else join[S](p1, this, p2, that); // TODO: remove [S] when SI-5548 is fixed
}
case (LongMap.Tip(key, value), x) => x.updateWith(key, value, (x, y) => f(key, y, x));
case (LongMap.Tip(key, value), x) => x.updateWith[S](key, value, (x, y) => f(key, y, x)); // TODO: remove [S] when SI-5548 is fixed
case (x, LongMap.Tip(key, value)) => x.updateWith[S](key, value, (x, y) => f(key, x, y));
case (LongMap.Nil, x) => x;
case (x, LongMap.Nil) => x;
Expand Down
5 changes: 2 additions & 3 deletions src/library/scala/util/parsing/combinator/Parsers.scala
Expand Up @@ -794,7 +794,7 @@ trait Parsers {
*/
def chainl1[T, U](first: => Parser[T], p: => Parser[U], q: => Parser[(T, U) => T]): Parser[T]
= first ~ rep(q ~ p) ^^ {
case x ~ xs => xs.foldLeft(x){(_, _) match {case (a, f ~ b) => f(a, b)}}
case x ~ xs => xs.foldLeft(x: T){case (a, f ~ b) => f(a, b)} // x's type annotation is needed to deal with changed type inference due to SI-5189
}

/** A parser generator that generalises the `rep1sep` generator so that `q`,
Expand All @@ -812,8 +812,7 @@ trait Parsers {
*/
def chainr1[T, U](p: => Parser[T], q: => Parser[(T, U) => U], combine: (T, U) => U, first: U): Parser[U]
= p ~ rep(q ~ p) ^^ {
case x ~ xs => (new ~(combine, x) :: xs).
foldRight(first){(_, _) match {case (f ~ a, b) => f(a, b)}}
case x ~ xs => (new ~(combine, x) :: xs).foldRight(first){case (f ~ a, b) => f(a, b)}
}

/** A parser generator for optional sub-phrases.
Expand Down
2 changes: 1 addition & 1 deletion src/scalap/scala/tools/scalap/scalax/rules/Rules.scala
Expand Up @@ -130,7 +130,7 @@ trait StateRules {
def rep(in : S, t : T) : Result[S, T, X] = {
if (finished(t)) Success(in, t)
else rule(in) match {
case Success(out, f) => rep(out, f(t))
case Success(out, f) => rep(out, f(t)) // SI-5189 f.asInstanceOf[T => T]
case Failure => Failure
case Error(x) => Error(x)
}
Expand Down
2 changes: 1 addition & 1 deletion test/files/jvm/typerep.scala
Expand Up @@ -161,7 +161,7 @@ object TypeRep {
}).asInstanceOf[TypeRep[Option[A]]]

def getType[A](x: List[A])(implicit rep: TypeRep[A]): TypeRep[List[A]] = (x match {
case h :: t => ListRep(getType(h))
case h :: t => ListRep(rep)
case Nil => NilRep
}).asInstanceOf[TypeRep[List[A]]]

Expand Down
8 changes: 8 additions & 0 deletions test/files/neg/t5189b.check
@@ -0,0 +1,8 @@
t5189b.scala:25: error: type mismatch;
found : TestNeg.Wrapped[?T2] where type ?T2 <: T
required: TestNeg.Wrapped[T]
Note: ?T2 <: T, but class Wrapped is invariant in type W.
You may wish to define W as +W instead. (SLS 4.5)
case Wrapper/*[_ <: T ]*/(wrapped) => wrapped // : Wrapped[_ <: T], which is a subtype of Wrapped[T] if and only if Wrapped is covariant in its type parameter
^
one error found
62 changes: 62 additions & 0 deletions test/files/neg/t5189b.scala
@@ -0,0 +1,62 @@
class TestPos {
class AbsWrapperCov[+A]
case class Wrapper[B](x: B) extends AbsWrapperCov[B]

def unwrap[T](x: AbsWrapperCov[T]): T = x match {
case Wrapper/*[_ <: T ]*/(x) => x // _ <: T, which is a subtype of T
}
}

object TestNeg extends App {
class AbsWrapperCov[+A]
case class Wrapper[B](x: Wrapped[B]) extends AbsWrapperCov[B]

/*
when inferring Wrapper's type parameter B from x's type AbsWrapperCov[T],
we must take into account that x's actual type is AbsWrapperCov[Tactual] forSome {type Tactual <: T}
as AbsWrapperCov is covariant in A -- in other words, we must not assume we know T exactly, all we know is its upper bound
since method application is the only way to generate this slack between run-time and compile-time types,
we'll simply replace the skolems that represent method type parameters as seen from the method's body by
other skolems that are (upper/lower)-bounded by the type-parameter skolems
(depending on whether the skolem appears in a covariant/contravariant position)
*/
def unwrap[T](x: AbsWrapperCov[T]): Wrapped[T] = x match {
case Wrapper/*[_ <: T ]*/(wrapped) => wrapped // : Wrapped[_ <: T], which is a subtype of Wrapped[T] if and only if Wrapped is covariant in its type parameter
}

class Wrapped[W](var cell: W) // must be invariant (to trigger the bug)

// class A { def imNotAB = println("notB")}
// class B
//
// val w = new Wrapped(new A)
// unwrap[Any](Wrapper(w)).cell = new B
// w.cell.imNotAB
}

// class TestPos1 {
// class Base[T]
// case class C[T](x: T) extends Base[T]
// def foo[T](b: Base[T]): T = b match { case C(x) => x }
//
// case class Span[K <: Ordered[K]](low: Option[K], high: Option[K]) extends Function1[K, Boolean] {
// override def equals(x$1: Any): Boolean = x$1 match {
// case Span((low$0 @ _), (high$0 @ _)) if low$0.equals(low).$amp$amp(high$0.equals(high)) => true
// case _ => false
// }
// def apply(k: K): Boolean = this match {
// case Span(Some(low), Some(high)) => (k >= low && k <= high)
// case Span(Some(low), None) => (k >= low)
// case Span(None, Some(high)) => (k <= high)
// case _ => false
// }
// }
// }
//
// class TestNeg1 {
// case class Foo[T, U](f: T => U)
// def f(x: Any): Any => Any = x match { case Foo(bar) => bar }
// // uh-oh, Any => Any should be Nothing => Any.
// }

0 comments on commit 29bcade

Please sign in to comment.