Skip to content

Commit

Permalink
Merge pull request #4101 from adriaanm/sam-ex
Browse files Browse the repository at this point in the history
[sammy] eta-expansion, overloading, existentials
  • Loading branch information
lrytz committed Nov 10, 2014
2 parents 02c0852 + cbca494 commit 5a7875f
Show file tree
Hide file tree
Showing 20 changed files with 183 additions and 29 deletions.
6 changes: 6 additions & 0 deletions src/compiler/scala/tools/nsc/typechecker/Infer.scala
Expand Up @@ -295,11 +295,17 @@ trait Infer extends Checkable {
&& !isByNameParamType(tp)
&& isCompatible(tp, dropByName(pt))
)
def isCompatibleSam(tp: Type, pt: Type): Boolean = {
val samFun = typer.samToFunctionType(pt)
(samFun ne NoType) && isCompatible(tp, samFun)
}

val tp1 = normalize(tp)

( (tp1 weak_<:< pt)
|| isCoercible(tp1, pt)
|| isCompatibleByName(tp, pt)
|| isCompatibleSam(tp, pt)
)
}
def isCompatibleArgs(tps: List[Type], pts: List[Type]) = (tps corresponds pts)(isCompatible)
Expand Down
76 changes: 62 additions & 14 deletions src/compiler/scala/tools/nsc/typechecker/Typers.scala
Expand Up @@ -741,6 +741,26 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
case _ =>
}

/**
* Convert a SAM type to the corresponding FunctionType,
* extrapolating BoundedWildcardTypes in the process
* (no type precision is lost by the extrapolation,
* but this facilitates dealing with the types arising from Java's use-site variance).
*/
def samToFunctionType(tp: Type, sam: Symbol = NoSymbol): Type = {
val samSym = sam orElse samOf(tp)

def correspondingFunctionSymbol = {
val numVparams = samSym.info.params.length
if (numVparams > definitions.MaxFunctionArity) NoSymbol
else FunctionClass(numVparams)
}

if (samSym.exists && samSym.owner != correspondingFunctionSymbol) // don't treat Functions as SAMs
wildcardExtrapolation(normalize(tp memberInfo samSym))
else NoType
}

/** Perform the following adaptations of expression, pattern or type `tree` wrt to
* given mode `mode` and given prototype `pt`:
* (-1) For expressions with annotated types, let AnnotationCheckers decide what to do
Expand Down Expand Up @@ -824,7 +844,7 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
case Block(_, tree1) => tree1.symbol
case _ => tree.symbol
}
if (!meth.isConstructor && isFunctionType(pt)) { // (4.2)
if (!meth.isConstructor && (isFunctionType(pt) || samOf(pt).exists)) { // (4.2)
debuglog(s"eta-expanding $tree: ${tree.tpe} to $pt")
checkParamsConvertible(tree, tree.tpe)
val tree0 = etaExpand(context.unit, tree, this)
Expand Down Expand Up @@ -2681,7 +2701,7 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
* `{
* def apply$body(p1: T1, ..., pN: TN): T = body
* new S {
* def apply(p1: T1, ..., pN: TN): T = apply$body(p1,..., pN)
* def apply(p1: T1', ..., pN: TN'): T' = apply$body(p1,..., pN)
* }
* }`
*
Expand All @@ -2691,6 +2711,10 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
*
* The `apply` method is identified by the argument `sam`; `S` corresponds to the argument `samClassTp`,
* and `resPt` is derived from `samClassTp` -- it may be fully defined, or not...
* If it is not fully defined, we derive `samClassTpFullyDefined` by inferring any unknown type parameters.
*
* The types T1' ... TN' and T' are derived from the method signature of the sam method,
* as seen from the fully defined `samClassTpFullyDefined`.
*
* The function's body is put in a method outside of the class definition to enforce scoping.
* S's members should not be in scope in `body`.
Expand All @@ -2702,6 +2726,22 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
* However T must be fully defined before we type the instantiation, as it'll end up as a parent type,
* which must be fully defined. Would be nice to have some kind of mechanism to insert type vars in a block of code,
* and have the instantiation of the first occurrence propagate to the rest of the block.
*
* TODO: by-name params
* scala> trait LazySink { def accept(a: => Any): Unit }
* defined trait LazySink
*
* scala> val f: LazySink = (a) => (a, a)
* f: LazySink = $anonfun$1@1fb26910
*
* scala> f(println("!"))
* <console>:10: error: LazySink does not take parameters
* f(println("!"))
* ^
*
* scala> f.accept(println("!"))
* !
* !
*/
def synthesizeSAMFunction(sam: Symbol, fun: Function, resPt: Type, samClassTp: Type, mode: Mode): Tree = {
// assert(fun.vparams forall (vp => isFullyDefined(vp.tpt.tpe))) -- by construction, as we take them from sam's info
Expand Down Expand Up @@ -2782,14 +2822,21 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
samClassTp
}

// `final override def ${sam.name}($p1: $T1, ..., $pN: $TN): $resPt = ${sam.name}\$body'($p1, ..., $pN)`
// what's the signature of the method that we should actually be overriding?
val samMethTp = samClassTpFullyDefined memberInfo sam
// Before the mutation, `tp <:< vpar.tpt.tpe` should hold.
// TODO: error message when this is not the case, as the expansion won't type check
// - Ti' <:< Ti and T <: T' must hold for the samDef body to type check
val funArgTps = foreach2(samMethTp.paramTypes, fun.vparams)((tp, vpar) => vpar.tpt setType tp)

// `final override def ${sam.name}($p1: $T1', ..., $pN: $TN'): ${samMethTp.finalResultType} = ${sam.name}\$body'($p1, ..., $pN)`
val samDef =
DefDef(Modifiers(FINAL | OVERRIDE | SYNTHETIC),
sam.name.toTermName,
Nil,
List(fun.vparams),
TypeTree(samBodyDef.tpt.tpe) setPos sampos.focus,
Apply(Ident(bodyName), fun.vparams map (p => Ident(p.name)))
TypeTree(samMethTp.finalResultType) setPos sampos.focus,
Apply(Ident(bodyName), fun.vparams map gen.paramToArg)
)

val serializableParentAddendum =
Expand Down Expand Up @@ -2819,6 +2866,11 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
)
}

// TODO: improve error reporting -- when we're in silent mode (from `silent(_.doTypedApply(tree, fun, args, mode, pt)) orElse onError`)
// the errors in the function don't get out...
if (block exists (_.isErroneous))
context.error(fun.pos, s"Could not derive subclass of $samClassTp\n (with SAM `def $sam$samMethTp`)\n based on: $fun.")

classDef.symbol addAnnotation SerialVersionUIDAnnotation
block
}
Expand All @@ -2839,7 +2891,7 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
* as `(a => a): Int => Int` should not (yet) get the sam treatment.
*/
val sam =
if (!settings.Xexperimental || pt.typeSymbol == FunctionSymbol) NoSymbol
if (pt.typeSymbol == FunctionSymbol) NoSymbol
else samOf(pt)

/* The SAM case comes first so that this works:
Expand All @@ -2849,15 +2901,11 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
* Note that the arity of the sam must correspond to the arity of the function.
*/
val samViable = sam.exists && sameLength(sam.info.params, fun.vparams)
val ptNorm = if (samViable) samToFunctionType(pt, sam) else pt
val (argpts, respt) =
if (samViable) {
val samInfo = pt memberInfo sam
(samInfo.paramTypes, samInfo.resultType)
} else {
pt baseType FunctionSymbol match {
case TypeRef(_, FunctionSymbol, args :+ res) => (args, res)
case _ => (fun.vparams map (_ => if (pt == ErrorType) ErrorType else NoType), WildcardType)
}
ptNorm baseType FunctionSymbol match {
case TypeRef(_, FunctionSymbol, args :+ res) => (args, res)
case _ => (fun.vparams map (_ => if (pt == ErrorType) ErrorType else NoType), WildcardType)
}

if (!FunctionSymbol.exists)
Expand Down
2 changes: 1 addition & 1 deletion src/reflect/scala/reflect/internal/Definitions.scala
Expand Up @@ -790,7 +790,7 @@ trait Definitions extends api.StandardDefinitions {
* The class defining the method is a supertype of `tp` that
* has a public no-arg primary constructor.
*/
def samOf(tp: Type): Symbol = {
def samOf(tp: Type): Symbol = if (!settings.Xexperimental) NoSymbol else {
// if tp has a constructor, it must be public and must not take any arguments
// (not even an implicit argument list -- to keep it simple for now)
val tpSym = tp.typeSymbol
Expand Down
16 changes: 16 additions & 0 deletions src/reflect/scala/reflect/internal/tpe/TypeMaps.scala
Expand Up @@ -422,6 +422,22 @@ private[internal] trait TypeMaps {
}
}

/**
* Get rid of BoundedWildcardType where variance allows us to do so.
* Invariant: `wildcardExtrapolation(tp) =:= tp`
*
* For example, the MethodType given by `def bla(x: (_ >: String)): (_ <: Int)`
* is both a subtype and a supertype of `def bla(x: String): Int`.
*/
object wildcardExtrapolation extends TypeMap(trackVariance = true) {
def apply(tp: Type): Type =
tp match {
case BoundedWildcardType(TypeBounds(lo, AnyTpe)) if variance.isContravariant => lo
case BoundedWildcardType(TypeBounds(NothingTpe, hi)) if variance.isCovariant => hi
case tp => mapOver(tp)
}
}

/** Might the given symbol be important when calculating the prefix
* of a type? When tp.asSeenFrom(pre, clazz) is called on `tp`,
* the result will be `tp` unchanged if `pre` is trivial and `clazz`
Expand Down
1 change: 1 addition & 0 deletions src/reflect/scala/reflect/runtime/JavaUniverseForce.scala
Expand Up @@ -170,6 +170,7 @@ trait JavaUniverseForce { self: runtime.JavaUniverse =>
this.dropSingletonType
this.abstractTypesToBounds
this.dropIllegalStarTypes
this.wildcardExtrapolation
this.IsDependentCollector
this.ApproximateDependentMap
this.wildcardToTypeVarMap
Expand Down
6 changes: 6 additions & 0 deletions test/files/neg/sammy_error_exist_no_crash.check
@@ -0,0 +1,6 @@
sammy_error_exist_no_crash.scala:5: error: Could not derive subclass of F[? >: String]
(with SAM `def method apply(s: String)Int`)
based on: ((x$1: String) => x$1.<parseInt: error>).
bar(_.parseInt)
^
one error found
1 change: 1 addition & 0 deletions test/files/neg/sammy_error_exist_no_crash.flags
@@ -0,0 +1 @@
-Xexperimental
6 changes: 6 additions & 0 deletions test/files/neg/sammy_error_exist_no_crash.scala
@@ -0,0 +1,6 @@
abstract class F[T] { def apply(s: T): Int }

object NeedsNiceError {
def bar(x: F[_ >: String]) = ???
bar(_.parseInt)
}
28 changes: 14 additions & 14 deletions test/files/neg/sammy_restrictions.scala
@@ -1,28 +1,28 @@
class NoAbstract
abstract class NoAbstract

class TwoAbstract { def ap(a: Int): Int; def pa(a: Int): Int }
abstract class TwoAbstract { def ap(a: Int): Int; def pa(a: Int): Int }

class Base // check that the super class constructor isn't considered.
class NoEmptyConstructor(a: Int) extends Base { def this(a: String) = this(0); def ap(a: Int): Int }
abstract class Base // check that the super class constructor isn't considered.
abstract class NoEmptyConstructor(a: Int) extends Base { def this(a: String) = this(0); def ap(a: Int): Int }

class OneEmptyConstructor() { def this(a: Int) = this(); def ap(a: Int): Int }
abstract class OneEmptyConstructor() { def this(a: Int) = this(); def ap(a: Int): Int }

class OneEmptySecondaryConstructor(a: Int) { def this() = this(0); def ap(a: Int): Int }
abstract class OneEmptySecondaryConstructor(a: Int) { def this() = this(0); def ap(a: Int): Int }

class MultipleConstructorLists()() { def ap(a: Int): Int }
abstract class MultipleConstructorLists()() { def ap(a: Int): Int }

class MultipleMethodLists()() { def ap(a: Int)(): Int }
abstract class MultipleMethodLists()() { def ap(a: Int)(): Int }

class ImplicitConstructorParam()(implicit a: String) { def ap(a: Int): Int }
abstract class ImplicitConstructorParam()(implicit a: String) { def ap(a: Int): Int }

class ImplicitMethodParam() { def ap(a: Int)(implicit b: String): Int }
abstract class ImplicitMethodParam() { def ap(a: Int)(implicit b: String): Int }

class PolyClass[T] { def ap(a: T): T }
abstract class PolyClass[T] { def ap(a: T): T }

class PolyMethod { def ap[T](a: T): T }
abstract class PolyMethod { def ap[T](a: T): T }

class OneAbstract { def ap(a: Any): Any }
class DerivedOneAbstract extends OneAbstract
abstract class OneAbstract { def ap(a: Int): Any }
abstract class DerivedOneAbstract extends OneAbstract

object Test {
implicit val s: String = ""
Expand Down
1 change: 1 addition & 0 deletions test/files/pos/sammy_exist.flags
@@ -0,0 +1 @@
-Xexperimental
17 changes: 17 additions & 0 deletions test/files/pos/sammy_exist.scala
@@ -0,0 +1,17 @@
// scala> typeOf[java.util.stream.Stream[_]].nonPrivateMember(TermName("map")).info
// [R](x$1: java.util.function.Function[_ >: T, _ <: R])java.util.stream.Stream[R]

// java.util.function.Function
trait Fun[A, B] { def apply(x: A): B }

// java.util.stream.Stream
class S[T](x: T) { def map[R](f: Fun[_ >: T, _ <: R]): R = f(x) }

class Bla { def foo: Bla = this }

// NOTE: inferred types show unmoored skolems, should pack them to display properly as bounded wildcards
object T {
val aBlaSAM = (new S(new Bla)).map(_.foo)
val fun: Fun[Bla, Bla] = (x: Bla) => x
val aBlaSAMX = (new S(new Bla)).map(fun)
}
1 change: 1 addition & 0 deletions test/files/pos/sammy_overload.flags
@@ -0,0 +1 @@
-Xexperimental
9 changes: 9 additions & 0 deletions test/files/pos/sammy_overload.scala
@@ -0,0 +1,9 @@
trait Consumer[T] {
def consume(x: T): Unit
}

object Test {
def foo(x: String): Unit = ???
def foo(): Unit = ???
val f: Consumer[_ >: String] = foo
}
1 change: 1 addition & 0 deletions test/files/pos/sammy_override.flags
@@ -0,0 +1 @@
-Xexperimental
8 changes: 8 additions & 0 deletions test/files/pos/sammy_override.scala
@@ -0,0 +1,8 @@
trait IntConsumer {
def consume(x: Int): Unit
}

object Test {
def anyConsumer(x: Any): Unit = ???
val f: IntConsumer = anyConsumer
}
1 change: 1 addition & 0 deletions test/files/pos/t8310.flags
@@ -0,0 +1 @@
-Xexperimental
22 changes: 22 additions & 0 deletions test/files/pos/t8310.scala
@@ -0,0 +1,22 @@
trait Comparinator[T] { def compare(a: T, b: T): Int }

object TestOkay {
def sort(x: Comparinator[_ >: String]) = ()
sort((a: String, b: String) => a.compareToIgnoreCase(b))
}

object TestOkay2 {
def sort[T](x: Comparinator[_ >: T]) = ()
sort((a: String, b: String) => a.compareToIgnoreCase(b))
}

object TestOkay3 {
def sort[T](xs: Option[T], x: Comparinator[_ >: T]) = ()
sort(Some(""), (a: String, b: String) => a.compareToIgnoreCase(b))
}

object TestKoOverloaded {
def sort[T](xs: Option[T]) = ()
def sort[T](xs: Option[T], x: Comparinator[_ >: T]) = ()
sort(Some(""), (a: String, b: String) => a.compareToIgnoreCase(b))
}
1 change: 1 addition & 0 deletions test/files/run/sammy_repeated.check
@@ -0,0 +1 @@
WrappedArray(1)
1 change: 1 addition & 0 deletions test/files/run/sammy_repeated.flags
@@ -0,0 +1 @@
-Xexperimental
8 changes: 8 additions & 0 deletions test/files/run/sammy_repeated.scala
@@ -0,0 +1,8 @@
trait RepeatedSink { def accept(a: Any*): Unit }

object Test {
def main(args: Array[String]): Unit = {
val f: RepeatedSink = (a) => println(a)
f.accept(1)
}
}

0 comments on commit 5a7875f

Please sign in to comment.