Skip to content

Commit b29c01b

Browse files
Aleksandar Prokopecaxel22
authored andcommitted
Fix SI-5284.
The problem was the false assumption that methods specialized on their type parameter, such as this one: class Foo[@SPEC(Int) T](val x: T) { def bar[@SPEC(Int) S >: T](f: S => S) = f(x) } have their normalized versions (`bar$mIc$sp`) never called from the base specialization class `Foo`. This meant that the implementation of `bar$mIc$sp` in `Foo` simply threw an exception. This assumption is not true, however. See this: object Baz { def apply[T]() = new Foo[T] } Calling `Baz.apply[Int]()` will create an instance of the base specialization class `Foo` at `Int`. Calling `bar` on this instance will be rewritten by specialization to calling `bar$mIc$sp`, hence the error. So, we have to emit a valid implementation for `bar`, obviously. Problem is, such an implementation would have conflicting type bounds in the base specialization class `Foo`, since we don't know if `T` is a subtype of `S = Int`. In other words, we cannot emit: def bar$mIc$sp(f: Int => Int) = f(x) // x: T without typechecking errors. However, notice that the bounds are valid if and only if `T = Int`. In the same time, invocations of `bar$mIc$sp` will only be emitted in callsites where the type bounds hold. This means we can cast the expressions in method applications to the required specialized type bound. The following changes have been made: 1) The decision of whether or not to create a normalized version of the specialized method is not done on the `conflicting` relation anymore. Instead, it's done based on the `satisfiable` relation, which is true if there is possibly an instantiation of the type parameters where the bounds hold. 2) The `satisfiable` method has a new variant called `satisfiableConstraints`, which does unification to figure out how the type parameters should be instantiated in order to satisfy the bounds. 3) The `Duplicators` are changed to transform a tree using the `castType` method which just returns the tree by default. In specialization, the `castType` in `Duplicators` is overridden, and uses a map from type parameters to types. This map is obtained by `satisfiableConstraints` from 2). If the type of the expression is not equal to the expected type, and this map contains a mapping to the expected type, then the tree is cast, as discussed above. Additional tests added. Review by @dragos Review by @VladUreche Conflicts: src/compiler/scala/tools/nsc/transform/SpecializeTypes.scala src/compiler/scala/tools/nsc/typechecker/Duplicators.scala
1 parent 16f5350 commit b29c01b

File tree

10 files changed

+180
-26
lines changed

10 files changed

+180
-26
lines changed

src/compiler/scala/tools/nsc/transform/Constructors.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ abstract class Constructors extends Transform with ast.TreeDSL {
323323
// statements coming from the original class need retyping in the current context
324324
debuglog("retyping " + stat2)
325325

326-
val d = new specializeTypes.Duplicator
326+
val d = new specializeTypes.Duplicator(Map[Symbol, Type]())
327327
d.retyped(localTyper.context1.asInstanceOf[d.Context],
328328
stat2,
329329
genericClazz,

src/compiler/scala/tools/nsc/transform/SpecializeTypes.scala

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,12 @@ abstract class SpecializeTypes extends InfoTransform with TypingTransformers {
450450

451451
/** Type parameters that survive when specializing in the specified environment. */
452452
def survivingParams(params: List[Symbol], env: TypeEnv) =
453-
params.filter(p => !p.isSpecialized || !isPrimitiveValueType(env(p)))
453+
params filter {
454+
p =>
455+
!p.isSpecialized ||
456+
!env.contains(p) ||
457+
!isPrimitiveValueType(env(p))
458+
}
454459

455460
/** Produces the symbols from type parameters `syms` of the original owner,
456461
* in the given type environment `env`. The new owner is `nowner`.
@@ -1176,7 +1181,7 @@ abstract class SpecializeTypes extends InfoTransform with TypingTransformers {
11761181
|| specializedTypeVars(t1).nonEmpty
11771182
|| specializedTypeVars(t2).nonEmpty)
11781183
}
1179-
1184+
11801185
env forall { case (tvar, tpe) =>
11811186
matches(tvar.info.bounds.lo, tpe) && matches(tpe, tvar.info.bounds.hi) || {
11821187
if (warnings)
@@ -1192,10 +1197,58 @@ abstract class SpecializeTypes extends InfoTransform with TypingTransformers {
11921197
}
11931198
}
11941199
}
1200+
1201+
def satisfiabilityConstraints(env: TypeEnv): Option[TypeEnv] = {
1202+
val noconstraints = Some(emptyEnv)
1203+
def matches(tpe1: Type, tpe2: Type): Option[TypeEnv] = {
1204+
val t1 = subst(env, tpe1)
1205+
val t2 = subst(env, tpe2)
1206+
// log("---------> " + tpe1 + " matches " + tpe2)
1207+
// log(t1 + ", " + specializedTypeVars(t1))
1208+
// log(t2 + ", " + specializedTypeVars(t2))
1209+
// log("unify: " + unify(t1, t2, env, false, false) + " in " + env)
1210+
if (t1 <:< t2) noconstraints
1211+
else if (specializedTypeVars(t1).nonEmpty) Some(unify(t1, t2, env, false, false) -- env.keys)
1212+
else if (specializedTypeVars(t2).nonEmpty) Some(unify(t2, t1, env, false, false) -- env.keys)
1213+
else None
1214+
}
1215+
1216+
env.foldLeft[Option[TypeEnv]](noconstraints) {
1217+
case (constraints, (tvar, tpe)) =>
1218+
val loconstraints = matches(tvar.info.bounds.lo, tpe)
1219+
val hiconstraints = matches(tpe, tvar.info.bounds.hi)
1220+
val allconstraints = for (c <- constraints; l <- loconstraints; h <- hiconstraints) yield c ++ l ++ h
1221+
allconstraints
1222+
}
1223+
}
11951224

1196-
class Duplicator extends {
1225+
/** This duplicator additionally performs casts of expressions if that is allowed by the `casts` map. */
1226+
class Duplicator(casts: Map[Symbol, Type]) extends {
11971227
val global: SpecializeTypes.this.global.type = SpecializeTypes.this.global
1198-
} with typechecker.Duplicators
1228+
} with typechecker.Duplicators {
1229+
private val (castfrom, castto) = casts.unzip
1230+
private object CastMap extends SubstTypeMap(castfrom.toList, castto.toList)
1231+
1232+
class BodyDuplicator(_context: Context) extends super.BodyDuplicator(_context) {
1233+
override def castType(tree: Tree, pt: Type): Tree = {
1234+
// log(" expected type: " + pt)
1235+
// log(" tree type: " + tree.tpe)
1236+
tree.tpe = if (tree.tpe != null) fixType(tree.tpe) else null
1237+
// log(" tree type: " + tree.tpe)
1238+
val ntree = if (tree.tpe != null && !(tree.tpe <:< pt)) {
1239+
val casttpe = CastMap(tree.tpe)
1240+
if (casttpe <:< pt) gen.mkCast(tree, casttpe)
1241+
else if (casttpe <:< CastMap(pt)) gen.mkCast(tree, pt)
1242+
else tree
1243+
} else tree
1244+
ntree.tpe = null
1245+
ntree
1246+
}
1247+
}
1248+
1249+
protected override def newBodyDuplicator(context: Context) = new BodyDuplicator(context)
1250+
1251+
}
11991252

12001253
/** A tree symbol substituter that substitutes on type skolems.
12011254
* If a type parameter is a skolem, it looks for the original
@@ -1475,14 +1528,14 @@ abstract class SpecializeTypes extends InfoTransform with TypingTransformers {
14751528
deriveDefDef(tree1)(transform)
14761529

14771530
case NormalizedMember(target) =>
1478-
debuglog("Normalized member: " + symbol + ", target: " + target)
1479-
if (target.isDeferred || conflicting(typeEnv(symbol))) {
1531+
val constraints = satisfiabilityConstraints(typeEnv(symbol))
1532+
log("constraints: " + constraints)
1533+
if (target.isDeferred || constraints == None) {
14801534
deriveDefDef(tree)(_ => localTyper typed gen.mkSysErrorCall("Fatal error in code generation: this should never be called."))
1481-
}
1482-
else {
1535+
} else {
14831536
// we have an rhs, specialize it
14841537
val tree1 = reportTypeError {
1485-
duplicateBody(ddef, target)
1538+
duplicateBody(ddef, target, constraints.get)
14861539
}
14871540
debuglog("implementation: " + tree1)
14881541
deriveDefDef(tree1)(transform)
@@ -1546,7 +1599,7 @@ abstract class SpecializeTypes extends InfoTransform with TypingTransformers {
15461599
val tree1 = deriveValDef(tree)(_ => body(symbol.alias).duplicate)
15471600
debuglog("now typing: " + tree1 + " in " + tree.symbol.owner.fullName)
15481601

1549-
val d = new Duplicator
1602+
val d = new Duplicator(emptyEnv)
15501603
val newValDef = d.retyped(
15511604
localTyper.context1.asInstanceOf[d.Context],
15521605
tree1,
@@ -1571,12 +1624,18 @@ abstract class SpecializeTypes extends InfoTransform with TypingTransformers {
15711624
super.transform(tree)
15721625
}
15731626
}
1574-
1575-
private def duplicateBody(tree: DefDef, source: Symbol) = {
1627+
1628+
/** Duplicate the body of the given method `tree` to the new symbol `source`.
1629+
*
1630+
* Knowing that the method can be invoked only in the `castmap` type environment,
1631+
* this method will insert casts for all the expressions of types mappend in the
1632+
* `castmap`.
1633+
*/
1634+
private def duplicateBody(tree: DefDef, source: Symbol, castmap: TypeEnv = emptyEnv) = {
15761635
val symbol = tree.symbol
15771636
val meth = addBody(tree, source)
15781637

1579-
val d = new Duplicator
1638+
val d = new Duplicator(castmap)
15801639
debuglog("-->d DUPLICATING: " + meth)
15811640
d.retyped(
15821641
localTyper.context1.asInstanceOf[d.Context],

src/compiler/scala/tools/nsc/typechecker/Duplicators.scala

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ abstract class Duplicators extends Analyzer {
2121

2222
def retyped(context: Context, tree: Tree): Tree = {
2323
resetClassOwners
24-
(new BodyDuplicator(context)).typed(tree)
24+
(newBodyDuplicator(context)).typed(tree)
2525
}
2626

2727
/** Retype the given tree in the given context. Use this method when retyping
@@ -37,15 +37,17 @@ abstract class Duplicators extends Analyzer {
3737

3838
envSubstitution = new SubstSkolemsTypeMap(env.keysIterator.toList, env.valuesIterator.toList)
3939
debuglog("retyped with env: " + env)
40-
(new BodyDuplicator(context)).typed(tree)
40+
newBodyDuplicator(context).typed(tree)
4141
}
4242

43+
protected def newBodyDuplicator(context: Context) = new BodyDuplicator(context)
44+
4345
def retypedMethod(context: Context, tree: Tree, oldThis: Symbol, newThis: Symbol): Tree =
44-
(new BodyDuplicator(context)).retypedMethod(tree.asInstanceOf[DefDef], oldThis, newThis)
46+
(newBodyDuplicator(context)).retypedMethod(tree.asInstanceOf[DefDef], oldThis, newThis)
4547

4648
/** Return the special typer for duplicate method bodies. */
4749
override def newTyper(context: Context): Typer =
48-
new BodyDuplicator(context)
50+
newBodyDuplicator(context)
4951

5052
private def resetClassOwners() {
5153
oldClassOwner = null
@@ -209,6 +211,11 @@ abstract class Duplicators extends Analyzer {
209211
}
210212
}
211213

214+
/** Optionally cast this tree into some other type, if required.
215+
* Unless overridden, just returns the tree.
216+
*/
217+
def castType(tree: Tree, pt: Type): Tree = tree
218+
212219
/** Special typer method for re-type checking trees. It expects a typed tree.
213220
* Returns a typed tree that has fresh symbols for all definitions in the original tree.
214221
*
@@ -319,10 +326,10 @@ abstract class Duplicators extends Analyzer {
319326
super.typed(atPos(tree.pos)(tree1), mode, pt)
320327

321328
case This(_) =>
322-
// log("selection on this, plain: " + tree)
329+
debuglog("selection on this, plain: " + tree)
323330
tree.symbol = updateSym(tree.symbol)
324-
tree.tpe = null
325-
val tree1 = super.typed(tree, mode, pt)
331+
val ntree = castType(tree, pt)
332+
val tree1 = super.typed(ntree, mode, pt)
326333
// log("plain this typed to: " + tree1)
327334
tree1
328335
/* no longer needed, because Super now contains a This(...)
@@ -358,16 +365,18 @@ abstract class Duplicators extends Analyzer {
358365
case EmptyTree =>
359366
// no need to do anything, in particular, don't set the type to null, EmptyTree.tpe_= asserts
360367
tree
361-
368+
362369
case _ =>
363-
// log("Duplicators default case: " + tree.summaryString + " -> " + tree)
370+
debuglog("Duplicators default case: " + tree.summaryString)
371+
debuglog(" ---> " + tree)
364372
if (tree.hasSymbol && tree.symbol != NoSymbol && (tree.symbol.owner == definitions.AnyClass)) {
365373
tree.symbol = NoSymbol // maybe we can find a more specific member in a subclass of Any (see AnyVal members, like ==)
366374
}
367-
tree.tpe = null
368-
super.typed(tree, mode, pt)
375+
val ntree = castType(tree, pt)
376+
super.typed(ntree, mode, pt)
369377
}
370378
}
379+
371380
}
372381
}
373382

test/files/pos/spec-params-new.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ class Foo[@specialized A: ClassTag] {
3131
val xs = new Array[A](1)
3232
xs(0) = x
3333
}
34-
}
34+
}

test/files/run/t5284.check

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
2

test/files/run/t5284.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
2+
3+
4+
5+
6+
/** Here we have a situation where a normalized method parameter `W`
7+
* is used in a position which accepts an instance of type `T` - we know we can
8+
* safely cast `T` to `W` whenever type bounds on `W` hold.
9+
*/
10+
object Test {
11+
def main(args: Array[String]) {
12+
val a = Blarg(Array(1, 2, 3))
13+
println(a.m((x: Int) => x + 1))
14+
}
15+
}
16+
17+
18+
object Blarg {
19+
def apply[T: Manifest](a: Array[T]) = new Blarg(a)
20+
}
21+
22+
23+
class Blarg[@specialized(Int) T: Manifest](val a: Array[T]) {
24+
def m[@specialized(Int) W >: T, @specialized(Int) S](f: W => S) = f(a(0))
25+
}

test/files/run/t5284b.check

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
17

test/files/run/t5284b.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
3+
4+
5+
6+
7+
/** Here we have a situation where a normalized method parameter `W`
8+
* is used in a position which expects a type `T` - we know we can
9+
* safely cast `W` to `T` whenever typebounds of `W` hold.
10+
*/
11+
object Test {
12+
def main(args: Array[String]) {
13+
val foo = Foo.createUnspecialized[Int]
14+
println(foo.bar(17))
15+
}
16+
}
17+
18+
19+
object Foo {
20+
def createUnspecialized[T] = new Foo[T]
21+
}
22+
23+
24+
class Foo[@specialized(Int) T] {
25+
val id: T => T = x => x
26+
27+
def bar[@specialized(Int) W <: T, @specialized(Int) S](w: W) = id(w)
28+
}

test/files/run/t5284c.check

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3

test/files/run/t5284c.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
2+
3+
4+
5+
6+
7+
/** Here we have a compound type `List[W]` used in
8+
* a position where `List[T]` is expected. The cast
9+
* emitted in the normalized `bar` is safe because the
10+
* normalized `bar` can only be called if the type
11+
* bounds hold.
12+
*/
13+
object Test {
14+
def main(args: Array[String]) {
15+
val foo = Foo.createUnspecialized[Int]
16+
println(foo.bar(List(1, 2, 3)))
17+
}
18+
}
19+
20+
21+
object Foo {
22+
def createUnspecialized[T] = new Foo[T]
23+
}
24+
25+
26+
class Foo[@specialized(Int) T] {
27+
val len: List[T] => Int = xs => xs.length
28+
29+
def bar[@specialized(Int) W <: T](ws: List[W]) = len(ws)
30+
}

0 commit comments

Comments
 (0)