Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
typedFunction undoes eta-expansion regardless of expected type
When recovering missing argument types for an
eta-expanded method value, rework the expected type
to a method type.
  • Loading branch information
adriaanm committed Mar 31, 2016
1 parent 8e32d00 commit 5d7d644
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 45 deletions.
15 changes: 0 additions & 15 deletions src/compiler/scala/tools/nsc/typechecker/EtaExpansion.scala
Expand Up @@ -15,23 +15,8 @@ import symtab.Flags._
* @version 1.0
*/
trait EtaExpansion { self: Analyzer =>

import global._

object etaExpansion {
private def isMatch(vparam: ValDef, arg: Tree) = arg match {
case Ident(name) => vparam.name == name
case _ => false
}

def unapply(tree: Tree): Option[(List[ValDef], Tree, List[Tree])] = tree match {
case Function(vparams, Apply(fn, args)) if (vparams corresponds args)(isMatch) =>
Some((vparams, fn, args))
case _ =>
None
}
}

/** <p>
* Expand partial function applications of type `type`.
* </p><pre>
Expand Down
61 changes: 31 additions & 30 deletions src/compiler/scala/tools/nsc/typechecker/Typers.scala
Expand Up @@ -2841,7 +2841,8 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
* - a type with a Single Abstract Method (under -Xexperimental for now).
*/
private def typedFunction(fun: Function, mode: Mode, pt: Type): Tree = {
val numVparams = fun.vparams.length
val vparams = fun.vparams
val numVparams = vparams.length
val FunctionSymbol =
if (numVparams > definitions.MaxFunctionArity) NoSymbol
else FunctionClass(numVparams)
Expand All @@ -2863,37 +2864,20 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
* TODO: handle vararg sams?
*/
val ptNorm =
if (samMatchesFunctionBasedOnArity(sam, fun.vparams)) samToFunctionType(pt, sam)
if (samMatchesFunctionBasedOnArity(sam, vparams)) samToFunctionType(pt, sam)
else pt
val (argpts, respt) =
ptNorm baseType FunctionSymbol match {
case TypeRef(_, FunctionSymbol, args :+ res) => (args, res)
case _ => (fun.vparams map (if (pt == ErrorType) (_ => ErrorType) else (_ => NoType)), WildcardType)
case _ => (vparams map (if (pt == ErrorType) (_ => ErrorType) else (_ => NoType)), WildcardType)
}


// if the function is `(a1: T1, ..., aN: TN) => fun(a1,..., aN)`, where Ti are not all fully defined,
// type `fun` directly
def typeUnEtaExpanded: Type = fun match {
case etaExpansion(_, fn, _) =>
silent(_.typed(fn, mode.forFunMode, pt)) filter (_ => context.undetparams.isEmpty) map { fn1 =>
// if context.undetparams is not empty, the function was polymorphic,
// so we need the missing arguments to infer its type. See #871
val ftpe = normalize(fn1.tpe) baseType FunctionClass(numVparams)
// println(s"typeUnEtaExpanded $fn : ${fn1.tpe} (unwrapped $fun) --> normalized: $ftpe")

if (isFunctionType(ftpe) && isFullyDefined(ftpe)) ftpe
else NoType
} orElse { _ => NoType }
case _ => NoType
}

if (!FunctionSymbol.exists) MaxFunctionArityError(fun)
else if (argpts.lengthCompare(numVparams) != 0) WrongNumberOfParametersError(fun, argpts)
else {
val paramsMissingType = mutable.ArrayBuffer.empty[ValDef] //.sizeHint(numVparams) probably useless, since initial size is 16 and max fun arity is 22
// first, try to define param types from expected function's arg types if needed
foreach2(fun.vparams, argpts) { (vparam, argpt) =>
foreach2(vparams, argpts) { (vparam, argpt) =>
if (vparam.tpt isEmpty) {
if (isFullyDefined(argpt)) vparam.tpt setType argpt
else paramsMissingType += vparam
Expand All @@ -2902,12 +2886,29 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
}
}

// if we had missing param types, see if we can undo eta-expansion and recover type info
val expectedFunTypeBeforeEtaExpansion =
if (paramsMissingType.isEmpty) NoType
else typeUnEtaExpanded
// If we're typing `(a1: T1, ..., aN: TN) => m(a1,..., aN)`, where some Ti are not fully defined,
// type `m` directly (undoing eta-expansion of method m) to determine the argument types.
val ptUnrollingEtaExpansion =
if (paramsMissingType.nonEmpty && pt != ErrorType) fun.body match {
case Apply(meth, args) if (vparams corresponds args) { case (p, Ident(name)) => p.name == name case _ => false } =>
val methArgs = NoSymbol.newSyntheticValueParams(argpts map { case NoType => WildcardType case tp => tp })
// we're looking for a method (as indicated by FUNmode), so let's make sure our expected type is a MethodType
val methPt = MethodType(methArgs, respt)

silent(_.typed(meth, mode.forFunMode, methPt)) filter (_ => context.undetparams.isEmpty) map { methTyped =>
// if context.undetparams is not empty, the method was polymorphic,
// so we need the missing arguments to infer its type. See #871
val funPt = normalize(methTyped.tpe) baseType FunctionClass(numVparams)
// println(s"typeUnEtaExpanded $meth : ${methTyped.tpe} --> normalized: $funPt")

if (isFunctionType(funPt) && isFullyDefined(funPt)) funPt
else null
} orElse { _ => null }
case _ => null
} else null


if (expectedFunTypeBeforeEtaExpansion ne NoType) typedFunction(fun, mode, expectedFunTypeBeforeEtaExpansion)
if (ptUnrollingEtaExpansion ne null) typedFunction(fun, mode, ptUnrollingEtaExpansion)
else {
// we ran out of things to try, missing parameter types are an irrevocable error
var issuedMissingParameterTypeError = false
Expand All @@ -2925,24 +2926,24 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
// thus, its symbol, which serves as the current context.owner, is not the right owner
// you won't know you're using the wrong owner until lambda lift crashes (unless you know better than to use the wrong owner)
val outerTyper = newTyper(context.outer)
val p = fun.vparams.head
val p = vparams.head
if (p.tpt.tpe == null) p.tpt setType outerTyper.typedType(p.tpt).tpe

outerTyper.synthesizePartialFunction(p.name, p.pos, paramSynthetic = false, fun.body, mode, pt)

case _ =>
val vparamSyms = fun.vparams map { vparam =>
val vparamSyms = vparams map { vparam =>
enterSym(context, vparam)
if (context.retyping) context.scope enter vparam.symbol
vparam.symbol
}
val vparams = fun.vparams mapConserve typedValDef
val vparamsTyped = vparams mapConserve typedValDef
val formals = vparamSyms map (_.tpe)
val body1 = typed(fun.body, respt)
val restpe = packedType(body1, fun.symbol).deconst.resultType
val funtpe = phasedAppliedType(FunctionSymbol, formals :+ restpe)

treeCopy.Function(fun, vparams, body1) setType funtpe
treeCopy.Function(fun, vparamsTyped, body1) setType funtpe
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions test/files/pos/fun_undo_eta.scala
@@ -0,0 +1,10 @@
class Test {
def m(i: Int) = i

def expectWild[A](f: A) = ???
def expectFun[A](f: A => Int) = ???

expectWild((i => m(i))) // manual eta expansion
expectWild(m(_)) // have to undo eta expansion with wildcard expected type
expectFun(m(_)) // have to undo eta expansion with function expected type
}

0 comments on commit 5d7d644

Please sign in to comment.