Permalink
Browse files

Eliminated VariantTypeMap.

Made variance tracking a constructor parameter to TypeMap.
Eliminated more variance duplication.
  • Loading branch information...
1 parent 9be6d05 commit 882f8e640b034cc69b122ac221f75cbe0018e2c3 @paulp paulp committed Jan 1, 2013
@@ -1459,7 +1459,7 @@ abstract class RefChecks extends InfoTransform with scala.reflect.internal.trans
hidden = o.isTerm || o.isPrivateLocal
o = o.owner
}
- if (!hidden) varianceValidator.escapedPrivateLocals += sym
+ if (!hidden) varianceValidator.escapedLocals += sym
}
def checkSuper(mix: Name) =
@@ -945,7 +945,7 @@ trait Typers extends Modes with Adaptations with Tags {
.setOriginal(tree)
val skolems = new mutable.ListBuffer[TypeSymbol]
- object variantToSkolem extends VariantTypeMap {
+ object variantToSkolem extends TypeMap(isTrackingVariance = true) {
def apply(tp: Type) = mapOver(tp) match {
// !!! FIXME - skipping this when variance.isInvariant allows unsoundness, see SI-5189
case TypeRef(NoPrefix, tpSym, Nil) if !variance.isInvariant && tpSym.isTypeParameterOrSkolem && tpSym.owner.isTerm =>
@@ -3932,18 +3932,25 @@ trait Types extends api.Types { self: SymbolTable =>
/** A prototype for mapping a function over all possible types
*/
- abstract class TypeMap extends (Type => Type) {
+ abstract class TypeMap(isTrackingVariance: Boolean) extends (Type => Type) {
+ def this() = this(isTrackingVariance = false)
def apply(tp: Type): Type
- /** Mix in VariantTypeMap if you want variances to be significant.
- */
- def variance: Variance = Invariant
+ private[this] var _variance: Variance = if (isTrackingVariance) Covariant else Invariant
+
+ def variance_=(x: Variance) = { assert(isTrackingVariance, this) ; _variance = x }
+ def variance = _variance
/** Map this function over given type */
def mapOver(tp: Type): Type = tp match {
case tr @ TypeRef(pre, sym, args) =>
val pre1 = this(pre)
- val args1 = args mapConserve this
+ val args1 = (
+ if (isTrackingVariance && args.nonEmpty && !variance.isInvariant && sym.typeParams.nonEmpty)
+ mapOverArgs(args, sym.typeParams)
+ else
+ args mapConserve this
+ )
if ((pre1 eq pre) && (args1 eq args)) tp
else copyTypeRef(tp, pre1, tr.coevolveSym(pre1), args1)
case ThisType(_) => tp
@@ -3955,12 +3962,12 @@ trait Types extends api.Types { self: SymbolTable =>
else singleType(pre1, sym)
}
case MethodType(params, result) =>
- val params1 = mapOver(params)
+ val params1 = flipped(mapOver(params))
val result1 = this(result)
if ((params1 eq params) && (result1 eq result)) tp
else copyMethodType(tp, params1, result1.substSym(params, params1))
case PolyType(tparams, result) =>
- val tparams1 = mapOver(tparams)
+ val tparams1 = flipped(mapOver(tparams))
val result1 = this(result)
if ((tparams1 eq tparams) && (result1 eq result)) tp
else PolyType(tparams1, result1.substSym(tparams, tparams1))
@@ -3975,7 +3982,7 @@ trait Types extends api.Types { self: SymbolTable =>
if ((thistp1 eq thistp) && (supertp1 eq supertp)) tp
else SuperType(thistp1, supertp1)
case TypeBounds(lo, hi) =>
- val lo1 = this(lo)
+ val lo1 = flipped(this(lo))
val hi1 = this(hi)
if ((lo1 eq lo) && (hi1 eq hi)) tp
else TypeBounds(lo1, hi1)
@@ -3998,7 +4005,7 @@ trait Types extends api.Types { self: SymbolTable =>
else OverloadedType(pre1, alts)
case AntiPolyType(pre, args) =>
val pre1 = this(pre)
- val args1 = args mapConserve (this)
+ val args1 = args mapConserve this
if ((pre1 eq pre) && (args1 eq args)) tp
else AntiPolyType(pre1, args1)
case tv@TypeVar(_, constr) =>
@@ -4026,16 +4033,41 @@ trait Types extends api.Types { self: SymbolTable =>
// throw new Error("mapOver inapplicable for " + tp);
}
- protected def mapOverArgs(args: List[Type], tparams: List[Symbol]): List[Type] =
- args mapConserve this
+ private def flip() = if (isTrackingVariance) variance = variance.flip
+ @inline final def flipped[T](body: => T): T = {
+ flip()
+ try body finally flip()
+ }
+ @inline final def varyOn(tparam: Symbol)(body: => Type): Type = {
+ val saved = variance
+ variance *= tparam.variance
+ try body finally variance = saved
+ }
+ protected def mapOverArgs(args: List[Type], tparams: List[Symbol]): List[Type] = (
+ if (isTrackingVariance)
+ map2Conserve(args, tparams)((arg, tparam) => varyOn(tparam)(this(arg)))
+ else
+ args mapConserve this
+ )
+ private def isInfoUnchanged(sym: Symbol) = {
+ val forceInvariance = isTrackingVariance && !variance.isInvariant && sym.isAliasType
+ val result = if (forceInvariance) {
+ val saved = variance
+ variance = Invariant
+ try this(sym.info) finally variance = saved
+ }
+ else this(sym.info)
+
+ (sym.info eq result)
+ }
/** Called by mapOver to determine whether the original symbols can
* be returned, or whether they must be cloned. Overridden in VariantTypeMap.
*/
protected def noChangeToSymbols(origSyms: List[Symbol]): Boolean = {
@tailrec def loop(syms: List[Symbol]): Boolean = syms match {
case Nil => true
- case x :: xs => (x.info eq this(x.info)) && loop(xs)
+ case x :: xs => isInfoUnchanged(x) && loop(xs)
}
loop(origSyms)
}
@@ -4189,7 +4221,7 @@ trait Types extends api.Types { self: SymbolTable =>
/** Used by existentialAbstraction.
*/
- class ExistentialExtrapolation(tparams: List[Symbol]) extends VariantTypeMap {
+ class ExistentialExtrapolation(tparams: List[Symbol]) extends TypeMap(isTrackingVariance = true) {
private val occurCount = mutable.HashMap[Symbol, Int]()
private def countOccs(tp: Type) = {
tp foreach {
@@ -4208,16 +4240,29 @@ trait Types extends api.Types { self: SymbolTable =>
apply(tpe)
}
+ /** If these conditions all hold:
+ * 1) we are in covariant (or contravariant) position
+ * 2) this type occurs exactly once in the existential scope
+ * 3) the widened upper (or lower) bound of this type contains no references to tparams
+ * Then we replace this lone occurrence of the type with the widened upper (or lower) bound.
+ * All other types pass through unchanged.
+ */
def apply(tp: Type): Type = {
val tp1 = mapOver(tp)
if (variance.isInvariant) tp1
else tp1 match {
case TypeRef(pre, sym, args) if tparams contains sym =>
val repl = if (variance.isPositive) dropSingletonType(tp1.bounds.hi) else tp1.bounds.lo
- //println("eliminate "+sym+"/"+repl+"/"+occurCount(sym)+"/"+(tparams exists (repl.contains)))//DEBUG
- if (!repl.typeSymbol.isBottomClass && occurCount(sym) == 1 && !(tparams exists (repl.contains)))
- repl
- else tp1
+ val count = occurCount(sym)
+ val containsTypeParam = tparams exists (repl contains _)
+ def msg = {
+ val word = if (variance.isPositive) "upper" else "lower"
+ s"Widened lone occurrence of $tp1 inside existential to $word bound"
+ }
+ if (!repl.typeSymbol.isBottomClass && count == 1 && !containsTypeParam)
+ logResult(msg)(repl)
+ else
+ tp1
case _ =>
tp1
}
@@ -5954,7 +5999,7 @@ trait Types extends api.Types { self: SymbolTable =>
* `f` maps all elements to themselves.
*/
def map2Conserve[A <: AnyRef, B](xs: List[A], ys: List[B])(f: (A, B) => A): List[A] =
- if (xs.isEmpty) xs
+ if (xs.isEmpty || ys.isEmpty) xs
else {
val x1 = f(xs.head, ys.head)
val xs1 = map2Conserve(xs.tail, ys.tail)(f)
@@ -74,6 +74,11 @@ final class Variance private (val flags: Int) extends AnyVal {
}
object Variance {
+ implicit class SbtCompat(val v: Variance) {
+ def < (other: Int) = v.flags < other
+ def > (other: Int) = v.flags > other
+ }
+
def fold(variances: List[Variance]): Variance = (
if (variances.isEmpty) Bivariant
else variances reduceLeft (_ & _)
@@ -14,82 +14,11 @@ import scala.collection.{ mutable, immutable }
trait Variances {
self: SymbolTable =>
- /** Used by ExistentialExtrapolation and adaptConstrPattern().
- * TODO - eliminate duplication with all the rest.
- */
- trait VariantTypeMap extends TypeMap {
- private[this] var _variance: Variance = Covariant
-
- override def variance = _variance
- def variance_=(x: Variance) = _variance = x
-
- override protected def noChangeToSymbols(origSyms: List[Symbol]) =
- //OPT inline from forall to save on #closures
- origSyms match {
- case sym :: rest =>
- val v = variance
- if (sym.isAliasType) variance = Invariant
- val result = this(sym.info)
- variance = v
- (result eq sym.info) && noChangeToSymbols(rest)
- case _ =>
- true
- }
-
- override protected def mapOverArgs(args: List[Type], tparams: List[Symbol]): List[Type] =
- map2Conserve(args, tparams) { (arg, tparam) =>
- val saved = variance
- variance *= tparam.variance
- try this(arg) finally variance = saved
- }
-
- /** Map this function over given type */
- override def mapOver(tp: Type): Type = tp match {
- case MethodType(params, result) =>
- variance = variance.flip
- val params1 = mapOver(params)
- variance = variance.flip
- val result1 = this(result)
- if ((params1 eq params) && (result1 eq result)) tp
- else copyMethodType(tp, params1, result1.substSym(params, params1))
- case PolyType(tparams, result) =>
- variance = variance.flip
- val tparams1 = mapOver(tparams)
- variance = variance.flip
- val result1 = this(result)
- if ((tparams1 eq tparams) && (result1 eq result)) tp
- else PolyType(tparams1, result1.substSym(tparams, tparams1))
- case TypeBounds(lo, hi) =>
- variance = variance.flip
- val lo1 = this(lo)
- variance = variance.flip
- val hi1 = this(hi)
- if ((lo1 eq lo) && (hi1 eq hi)) tp
- else TypeBounds(lo1, hi1)
- case tr @ TypeRef(pre, sym, args) =>
- val pre1 = this(pre)
- val args1 =
- if (args.isEmpty)
- args
- else if (variance.isInvariant) // fast & safe path: don't need to look at typeparams
- args mapConserve this
- else {
- val tparams = sym.typeParams
- if (tparams.isEmpty) args
- else mapOverArgs(args, tparams)
- }
- if ((pre1 eq pre) && (args1 eq args)) tp
- else copyTypeRef(tp, pre1, tr.coevolveSym(pre1), args1)
- case _ =>
- super.mapOver(tp)
- }
- }
-
/** Used in Refchecks.
* TODO - eliminate duplication with varianceInType
*/
class VarianceValidator extends Traverser {
- val escapedPrivateLocals = mutable.HashSet[Symbol]()
+ val escapedLocals = mutable.HashSet[Symbol]()
protected def issueVarianceError(base: Symbol, sym: Symbol, required: Variance): Unit = ()
@@ -104,10 +33,10 @@ trait Variances {
)
// return Bivariant if `sym` is local to a term
// or is private[this] or protected[this]
- private def isLocalOnly(sym: Symbol) = !sym.owner.isClass || (
+ def isLocalOnly(sym: Symbol) = !sym.owner.isClass || (
sym.isTerm
&& (sym.isPrivateLocal || sym.isProtectedLocal || sym.isSuperAccessor) // super accessors are implicitly local #4345
- && !escapedPrivateLocals(sym)
+ && !escapedLocals(sym)
)
/** Validate variance of info of symbol `base` */
@@ -126,20 +55,18 @@ trait Variances {
* because there may be references to the type parameter that are not checked.
*/
def relativeVariance(tvar: Symbol): Variance = {
- val clazz = tvar.owner
- var sym = base
- var state: Variance = Covariant
- while (sym != clazz && !state.isBivariant) {
- if (isFlipped(sym, tvar))
- state = state.flip
- else if (isLocalOnly(sym))
- state = Bivariant
- else if (sym.isAliasType)
- state = if (sym.isOverridingSymbol) Invariant else Bivariant
-
- sym = sym.owner
- }
- state
+ def nextVariance(sym: Symbol, v: Variance): Variance = (
+ if (isFlipped(sym, tvar)) v.flip
+ else if (isLocalOnly(sym)) Bivariant
+ else if (!sym.isAliasType) v
+ else if (sym.isOverridingSymbol) Invariant
+ else Bivariant
+ )
+ def loop(sym: Symbol, v: Variance): Variance = (
+ if (sym == tvar.owner || v.isBivariant) v
+ else loop(sym.owner, nextVariance(sym, v))
+ )
+ loop(base, Covariant)
}
/** Validate that the type `tp` is variance-correct, assuming
@@ -156,15 +83,17 @@ trait Variances {
case ConstantType(_) =>
case SingleType(pre, sym) =>
validateVariance(pre, variance)
+ case TypeRef(_, sym, _) if sym.isAliasType => validateVariance(tp.normalize, variance)
+
case TypeRef(pre, sym, args) =>
-// println("validate "+sym+" at "+relativeVariance(sym))
- if (sym.isAliasType/* && relativeVariance(sym) == Bivariant*/)
- validateVariance(tp.normalize, variance)
- else if (!sym.variance.isInvariant) {
- val v = relativeVariance(sym)
- val requiredVariance = v * variance
- if (!v.isBivariant && sym.variance != requiredVariance)
- issueVarianceError(base, sym, requiredVariance)
+ if (!sym.variance.isInvariant) {
+ val relative = relativeVariance(sym)
+ val required = relative * variance
+ if (!relative.isBivariant) {
+ log(s"verifying $sym (${sym.variance}${sym.locationString}) is $required at $base in ${base.owner}")
+ if (sym.variance != required)
+ issueVarianceError(base, sym, required)
+ }
}
validateVariance(pre, variance)
// @M for higher-kinded typeref, args.isEmpty
@@ -213,18 +142,18 @@ trait Variances {
}
override def traverse(tree: Tree) {
- // def local = tree.symbol.hasLocalFlag
+ def local = tree.symbol.hasLocalFlag
tree match {
case ClassDef(_, _, _, _) | TypeDef(_, _, _, _) =>
validateVariance(tree.symbol)
super.traverse(tree)
// ModuleDefs need not be considered because they have been eliminated already
case ValDef(_, _, _, _) =>
- if (!tree.symbol.hasLocalFlag)
+ if (!local)
validateVariance(tree.symbol)
case DefDef(_, _, tparams, vparamss, _, _) =>
// No variance check for object-private/protected methods/values.
- if (!tree.symbol.hasLocalFlag) {
+ if (!local) {
validateVariance(tree.symbol)
traverseTrees(tparams)
traverseTreess(vparamss)

0 comments on commit 882f8e6

Please sign in to comment.