Skip to content

Commit

Permalink
More precise prototype for args of overloaded method
Browse files Browse the repository at this point in the history
Normally, overload resolution types the arguments to the alternatives
without an expected type. However, typing function literals and
eta-expansion are driven by the expected type:

  - function literals usually don't have parameter types, which are
    derived from the expected type;

  - eta-expansion right now only happens when a function/sam type is
    expected.

(Dotty side-steps these issues by eta-expanding regardless of
expected type.)

Now that the collections are full of overloaded HO methods, we should
try harder to type check them nicely.

To avoid breaking existing code, we only provide an expected type (for
each argument position) when:

 - there is at least one FunctionN type expected by one of the
   overloads: in this case, the expected type is a FunctionN[Ti, ?],
   where Ti are the argument types (they must all be =:=), and the
   expected result type is elided using a wildcard. This does not
   exclude any overloads that expect a SAM, because they conform to a
   function type through SAM conversion

 - OR: all overloads expect a SAM type of the same class, but with
   potentially varying result types (argument types must be =:=)

We allow polymorphic cases, as long as the types parameters are
instantiated by the AntiPolyType prefix.

In all other cases, the old behavior is maintained: Wildcard is
expected.

(Slightly) more formally:

Consider an overloaded method `m_i`, with `N` overloads `i = 1..N`,
and an expected argument type at index `j`, `a_ij`:

```
def m_1(... a_1j, ...)
..
def m_N(... a_Nj, ...)
```

Any polymorphic method `m_i` will be reduced to the monomorphic case
by pushing down the method's `PolyType` to its arguments `a_ij`.

The expected type for the argument at index `j` will be more
precise than the usual `WildcardType` (`?`), if all types `a_1j..a_Nj`
are function-ish types that denote the same parameter types `p1..pM`.

A "function-ish" type is a `FunctionN[p1,...,pM]` (or
`PartialFunction`), or the equivalent SAM type. (We first unwrap any
PolyTypes.)

The non-wildcard expected type will be
  - `PartialFunction[p1, ?]`, if an `a_ij` expects a partial function;
  - else, if there is a subclass of `FunctionM` among the `a_ij`,
    it is `FunctionM[p1, pM, ?]`;
  - else, if all `a_ij` are of the same SAM type (allowing
    for varying result types), that SAM type (with `?` result type).

In each case, any type parameter not already resolved by overloading,
or the outer context, it approximated by `?`.

PS: type equivalence is decided as `tp1 <:< tp2 && tp2 <:< tp1`,
and not `tp1 =:= tp2` (the latter is actually stricter).
  • Loading branch information
adriaanm committed Aug 10, 2018
1 parent d09589d commit cd70540
Show file tree
Hide file tree
Showing 9 changed files with 314 additions and 69 deletions.
69 changes: 9 additions & 60 deletions src/compiler/scala/tools/nsc/typechecker/Typers.scala
Expand Up @@ -2550,21 +2550,17 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
def synthesizePartialFunction(paramName: TermName, paramPos: Position, paramSynthetic: Boolean,
tree: Tree, mode: Mode, pt: Type): Tree = {
assert(pt.typeSymbol == PartialFunctionClass, s"PartialFunction synthesis for match in $tree requires PartialFunction expected type, but got $pt.")
val targs = partialFunctionArgTypeFromProto(pt)

// if targs.head isn't fully defined, we can't translate --> error
targs match {
case argTp :: _ if isFullyDefined(argTp) => // ok
case _ => // uh-oh
MissingParameterTypeAnonMatchError(tree, pt)
return setError(tree)
}
val (argTp, resTp) = partialFunctionArgResTypeFromProto(pt)

// if argTp isn't fully defined, we can't translate --> error
// NOTE: resTp still might not be fully defined
val argTp :: resTp :: Nil = targs
if (!isFullyDefined(argTp)) {
MissingParameterTypeAnonMatchError(tree, pt)
return setError(tree)
}

// targs must conform to Any for us to synthesize an applyOrElse (fallback to apply otherwise -- typically for @cps annotated targs)
val targsValidParams = targs forall (_ <:< AnyTpe)
val targsValidParams = (argTp <:< AnyTpe) && (resTp <:< AnyTpe)

val anonClass = context.owner newAnonymousFunctionClass tree.pos addAnnotation SerialVersionUIDAnnotation

Expand Down Expand Up @@ -3314,59 +3310,12 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
def handleOverloaded = {
val undetparams = context.undetparams

def funArgTypes(tpAlts: List[(Type, Symbol)]) = tpAlts.map { case (tp, alt) =>
val relTp = tp.asSeenFrom(pre, alt.owner)
functionOrPfOrSamArgTypes(relTp)
}

def functionProto(argTpWithAlt: List[(Type, Symbol)]): Type =
try functionType(funArgTypes(argTpWithAlt).transpose.map(lub), WildcardType)
catch { case _: IllegalArgumentException => WildcardType }

def partialFunctionProto(argTpWithAlt: List[(Type, Symbol)]): Type =
try appliedType(PartialFunctionClass, funArgTypes(argTpWithAlt).transpose.map(lub) :+ WildcardType)
catch { case _: IllegalArgumentException => WildcardType }

// To propagate as much information as possible to typedFunction, which uses the expected type to
// infer missing parameter types for Function trees that we're typing as arguments here,
// we expand the parameter types for all alternatives to the expected argument length,
// then transpose to get a list of alternative argument types (push down the overloading to the arguments).
// Thus, for each `arg` in `args`, the corresponding `argPts` in `altArgPts` is a list of expected types
// for `arg`. Depending on which overload is picked, only one of those expected types must be met, but
// we're in the process of figuring that out, so we'll approximate below by normalizing them to function types
// and lubbing the argument types (we treat SAM and FunctionN types equally, but non-function arguments
// do not receive special treatment: they are typed under WildcardType.)
val altArgPts =
if (settings.isScala212 && args.exists(t => treeInfo.isFunctionMissingParamType(t) || treeInfo.isPartialFunctionMissingParamType(t)))
try alts.map { alt =>
val paramTypes = pre.memberType(alt) match {
case mt @ MethodType(_, _) => mt.paramTypes
case PolyType(_, mt @ MethodType(_, _)) => mt.paramTypes
case t => throw new RuntimeException("Expected MethodType or PolyType of MethodType, got "+t)
}
formalTypes(paramTypes, argslen).map(ft => (ft, alt))
}.transpose // do least amount of work up front
catch { case _: IllegalArgumentException => args.map(_ => Nil) } // fail safe in case formalTypes fails to align to argslen
else args.map(_ => Nil) // will type under argPt == WildcardType

val (args1, argTpes) = context.savingUndeterminedTypeParams() {
val amode = forArgMode(fun, mode)

map2(args, altArgPts) { (arg, argPtAlts) =>
mapWithIndex(args) { (arg, argIdx) =>
def typedArg0(tree: Tree) = {
// if we have an overloaded HOF such as `(f: Int => Int)Int <and> (f: Char => Char)Char`,
// and we're typing a function like `x => x` for the argument, try to collapse
// the overloaded type into a single function type from which `typedFunction`
// can derive the argument type for `x` in the function literal above
val argPt =
if (argPtAlts.isEmpty) WildcardType
else if (treeInfo.isFunctionMissingParamType(tree)) functionProto(argPtAlts)
else if (treeInfo.isPartialFunctionMissingParamType(tree)) {
if (argPtAlts.exists(ts => isPartialFunctionType(ts._1))) partialFunctionProto(argPtAlts)
else functionProto(argPtAlts)
} else WildcardType

val argTyped = typedArg(tree, amode, BYVALmode, argPt)
val argTyped = typedArg(tree, amode, BYVALmode, OverloadedArgFunProto(argIdx, pre, alts))
(argTyped, argTyped.tpe.deconst)
}

Expand Down
7 changes: 5 additions & 2 deletions src/reflect/scala/reflect/internal/Definitions.scala
Expand Up @@ -723,9 +723,12 @@ trait Definitions extends api.StandardDefinitions {
)

// @requires pt.typeSymbol == PartialFunctionClass
def partialFunctionArgTypeFromProto(pt: Type) =
def partialFunctionArgResTypeFromProto(pt: Type): (Type, Type) =
pt match {
case _ => pt.dealiasWiden.typeArgs
case oap: OverloadedArgFunProto => (oap.hofParamTypes.head, WildcardType)
case _ =>
val arg :: res :: Nil = pt.baseType(PartialFunctionClass).typeArgs
(arg, res)
}

// the number of arguments expected by the function described by `tp` (a FunctionN or SAM type),
Expand Down
136 changes: 136 additions & 0 deletions src/reflect/scala/reflect/internal/Types.scala
Expand Up @@ -1184,6 +1184,142 @@ trait Types
def toVariantType: Type = NoType
}

/** Help infer parameter types for function arguments to overloaded methods.
*
* Normally, overload resolution types the arguments to the alternatives without an expected type.
* However, typing function literals and eta-expansion are driven by the expected type:
* - function literals usually don't have parameter types, which are derived from the expected type;
* - eta-expansion right now only happens when a function/sam type is expected.
*
* Now that the collections are full of overloaded HO methods, we should try harder to type check them nicely.
*
* To avoid breaking existing code, we only provide an expected type (for each argument position) when:
* - there is at least one FunctionN type expected by one of the overloads:
* in this case, the expected type is a FunctionN[Ti, ?], where Ti are the argument types (they must all be =:=),
* and the expected result type is elided using a wildcard.
* This does not exclude any overloads that expect a SAM, because they conform to a function type through SAM conversion
* - OR: all overloads expect a SAM type of the same class, but with potentially varying result types (argument types must be =:=)
*
* We allow polymorphic cases, as long as the types parameters are instantiated by the AntiPolyType prefix.
*
* In all other cases, the old behavior is maintained: Wildcard is expected.
*/
case class OverloadedArgFunProto(argIdx: Int, pre: Type, alternatives: List[Symbol]) extends ProtoType with SimpleTypeProxy {
override def safeToString: String = underlying.safeToString
override def kind = "OverloadedArgFunProto"

override def underlying: Type = functionArgsProto

// Always match if we couldn't collapse the expected types contributed for this argument by the alternatives.
// TODO: could we just match all function-ish types as an optimization? We previously used WildcardType
override def isMatchedBy(tp: Type, depth: Depth): Boolean =
isPastTyper || underlying == WildcardType ||
isSubType(tp, underlying, depth) ||
// NOTE: converting tp to a function type won't work, since `tp` need not be an actual sam type,
// just some subclass of the sam expected by one of our overloads
sameTypesFoldedSam.exists { underlyingSam => isSubType(tp, underlyingSam, depth) } // overload_proto_collapse.scala:55

// Empty signals failure. We don't consider the 0-ary HOF case, since we are only concerned with inferring param types for these functions anyway
def hofParamTypes = functionOrPfOrSamArgTypes(underlying)

override def expectsFunctionType: Boolean = hofParamTypes.nonEmpty

// TODO: include result type?
override def asFunctionType =
if (expectsFunctionType) functionType(hofParamTypes, WildcardType)
else NoType

override def mapOver(map: TypeMap): Type = {
val pre1 = pre.mapOver(map)
val alts1 = map.mapOver(alternatives)
if ((pre ne pre1) || (alternatives ne alts1)) OverloadedArgFunProto(argIdx, pre1, alts1)
else this
}

// TODO
// override def registerTypeEquality(tp: Type): Boolean = functionArgsProto =:= tp


// TODO: use =:=, but `!(typeOf[String with AnyRef] =:= typeOf[String])` (https://github.com/scala/scala-dev/issues/530)
private def same(x: Type, y: Type) = (x <:< y) && (y <:< x)

private object ParamAtIdx {
def unapply(params: List[Symbol]): Option[Type] = {
lazy val lastParamTp = params.last.tpe

// if we're asking for the last argument, or past, and it happens to be a repeated param -- strip the vararg marker and return the type
if (params.nonEmpty && params.lengthCompare(argIdx + 1) <= 0 && isRepeatedParamType(lastParamTp)) {
Some(lastParamTp.dealiasWiden.typeArgs.head)
} else if (params.isDefinedAt(argIdx)) {
Some(params(argIdx).tpe)
} else None
}
}


private def toWild(tp: Type): Type = tp match {
case PolyType(tparams, tp) => new SubstWildcardMap(tparams).apply(tp)
case tp => tp
}

private lazy val sameTypesFolded = {
// Collect all expected types contributed by the various alternatives for this argument (TODO: repeated params?)
// Relative to `pre` at `alt.owner`, with `alt`'s type params approximated.
val altParamTps =
alternatives map { alt =>
// Use memberType so that a pre: AntiPolyType can instantiate its type params
pre.memberType(alt) match {
case PolyType(tparams, MethodType(ParamAtIdx(paramTp), res)) => PolyType(tparams, paramTp.asSeenFrom(pre, alt.owner))
case MethodType(ParamAtIdx(paramTp), res) => paramTp.asSeenFrom(pre, alt.owner)
case _ => NoType
}
}

altParamTps.foldLeft(Nil: List[Type]) {
case (acc, NoType | WildcardType) => acc
case (acc, tp) => if (acc.exists(same(tp, _))) acc else tp :: acc
}
}

private lazy val sameTypesFoldedSam =
sameTypesFolded.iterator.map(toWild).filter(tp => samOf(tp).exists).toList

// Try to collapse all expected argument types (already distinct by =:=) into a single expected type,
// so that we can use it to as the expected type to drive parameter type inference for a function literal argument.
private lazy val functionArgsProto = {
val ABORT = (NoType, false, false)

// we also consider any function-ish type equal as long as the argument types are
def sameHOArgTypes(tp1: Type, tp2: Type) = tp1 == WildcardType || {
val hoArgTypes1 = functionOrPfOrSamArgTypes(tp1.resultType)
// println(s"sameHOArgTypes($tp1, $tp2) --> $hoArgTypes1 === $hoArgTypes2 : $res")
hoArgTypes1.nonEmpty && hoArgTypes1.corresponds(functionOrPfOrSamArgTypes(tp2.resultType))(same)
}

// TODO: compute functionOrPfOrSamArgTypes during fold?
val (sameHoArgTypesFolded, partialFun, regularFun) =
sameTypesFolded.foldLeft((WildcardType: Type, false, false)) {
case (ABORT, _) => ABORT
case ((acc, partialFun, regularFun), tp) if sameHOArgTypes(acc, tp) =>
val wild = toWild(tp)
(tp, partialFun || isPartialFunctionType(wild), regularFun || isFunctionType(wild))
case _ => ABORT // different HO argument types encountered
}

if ((sameHoArgTypesFolded eq WildcardType) || (sameHoArgTypesFolded eq NoType)) WildcardType
else functionOrPfOrSamArgTypes(toWild(sameHoArgTypesFolded)) match {
case Nil => WildcardType // TODO: can we retain some of this?
case hofArgs =>
if (partialFun) appliedType(PartialFunctionClass, hofArgs :+ WildcardType)
else if (regularFun) functionType(hofArgs, WildcardType)
// if we saw a variety of SAMs, can't collapse them -- what if they were accidental sams and we're not going to supply a function literal?
else if (sameTypesFolded.lengthCompare(1) == 0) toWild(sameTypesFolded.head)
else WildcardType
}

}
}

/** An object representing a non-existing type */
case object NoType extends Type {
override def isTrivial: Boolean = true
Expand Down
1 change: 1 addition & 0 deletions src/reflect/scala/reflect/runtime/JavaUniverseForce.scala
Expand Up @@ -142,6 +142,7 @@ trait JavaUniverseForce { self: runtime.JavaUniverse =>
this.ErrorType
this.WildcardType
this.BoundedWildcardType
this.OverloadedArgFunProto
this.NoType
this.NoPrefix
this.ThisType
Expand Down
7 changes: 2 additions & 5 deletions test/files/neg/t6214.check
@@ -1,7 +1,4 @@
t6214.scala:5: error: ambiguous reference to overloaded definition,
both method m in object Test of type (f: Int => Unit)Int
and method m in object Test of type (f: String => Unit)Int
match argument types (Any => Unit)
t6214.scala:5: error: missing parameter type
m { s => case class Foo() }
^
^
one error found
95 changes: 95 additions & 0 deletions test/files/pos/overload_proto.scala
@@ -0,0 +1,95 @@
object Util {
def mono(x: Int) = x
def poly[T](x: T): T = x
}

trait FunSam[-T, +R] { def apply(x: T): R }


trait TFun { def map[T](f: T => Int): Unit = () }
object Fun extends TFun { import Util._
def map[T: scala.reflect.ClassTag](f: T => Int): Unit = ()

map(mono)
map(mono _)
map(x => mono(x))

// can't infer polymorphic type for function parameter:
// map(poly)
// map(poly _)
// map(x => poly(x))
}

trait TSam { def map[T](f: T FunSam Int): Unit = () }
object Sam extends TSam { import Util._
def map[T: scala.reflect.ClassTag](f: T `FunSam` Int): Unit = ()

map(mono) // sam
map(mono _) // sam
map(x => mono(x)) // sam

// can't infer polymorphic type for function parameter:
// map(poly)
// map(poly _)
// map(x => poly(x))
}

trait IntFun { def map[T](f: Int => T): Unit = () }
object int_Fun extends IntFun { import Util._
def map[T: scala.reflect.ClassTag](f: Int => T): Unit = ()

map(mono)
map(mono _)
map(x => mono(x))

map(poly)
map(poly _)
map(x => poly(x))
}

trait IntSam { def map[T](f: Int FunSam T): Unit = () }
object int_Sam extends IntSam { import Util._
def map[T: scala.reflect.ClassTag](f: Int `FunSam` T): Unit = ()

map(mono) // sam
map(mono _) // sam
map(x => mono(x)) // sam

map(poly) // sam
map(poly _) // sam
map(x => poly(x)) // sam
}


/*
eta_overload_hof.scala:27: error: missing argument list for method mono in object Util
Unapplied methods are only converted to functions when a function type is expected.
You can make this conversion explicit by writing `mono _` or `mono(_)` instead of `mono`.
map(mono)
^
eta_overload_hof.scala:46: error: type mismatch;
found : Nothing => Nothing
required: ?<: Int => ?
map(poly _)
^
eta_overload_hof.scala:54: error: missing argument list for method mono in object Util
Unapplied methods are only converted to functions when a function type is expected.
You can make this conversion explicit by writing `mono _` or `mono(_)` instead of `mono`.
map(mono)
^
eta_overload_hof.scala:58: error: missing argument list for method poly in object Util
Unapplied methods are only converted to functions when a function type is expected.
You can make this conversion explicit by writing `poly _` or `poly(_)` instead of `poly`.
map(poly)
^
eta_overload_hof.scala:59: error: overloaded method value map with alternatives:
[T](f: FunSam[Int,T])(implicit evidence$4: scala.reflect.ClassTag[T])Unit <and>
[T](f: FunSam[Int,T])Unit
cannot be applied to (Nothing => Nothing)
map(poly _)
^
eta_overload_hof.scala:60: error: missing parameter type
map(x => poly(x))
^
* */
7 changes: 7 additions & 0 deletions test/files/pos/overload_proto_accisam.scala
@@ -0,0 +1,7 @@
// TODO make independent of java.io.OutputStream, but obvious way does not capture the bug (see didInferSamType and OverloadedArgFunProto)
class Test {
def overloadedAccidentalSam(a: java.io.OutputStream, b: String) = ???
def overloadedAccidentalSam(a: java.io.OutputStream, b: Any)= ???

overloadedAccidentalSam(??? : java.io.OutputStream, null)
}

0 comments on commit cd70540

Please sign in to comment.