diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index fa952bae91da..7736a63cdfde 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1921,15 +1921,39 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer NoType } - pt match { - case pt: TypeVar - if untpd.isFunctionWithUnknownParamType(tree) && !calleeType.exists => - // try to instantiate `pt` if this is possible. If it does not - // work the error will be reported later in `inferredParam`, - // when we try to infer the parameter type. - isFullyDefined(pt, ForceDegree.flipBottom) - case _ => - } + /** Try to instantiate one type variable bounded by function types that appear + * deeply inside `tp`, including union or intersection types. + */ + def tryToInstantiateDeeply(tp: Type): Boolean = tp match + case tp: AndOrType => + tryToInstantiateDeeply(tp.tp1) + || tryToInstantiateDeeply(tp.tp2) + case tp: FlexibleType => + tryToInstantiateDeeply(tp.hi) + case tp: TypeVar if isConstrainedByFunctionType(tp) => + // Only instantiate if the type variable is constrained by function types + isFullyDefined(tp, ForceDegree.flipBottom) + case _ => false + + def isConstrainedByFunctionType(tvar: TypeVar): Boolean = + val origin = tvar.origin + val bounds = ctx.typerState.constraint.bounds(origin) + def containsFunctionType(tp: Type): Boolean = tp.dealias match + case tp if defn.isFunctionType(tp) => true + case SAMType(_, _) => true + case tp: AndOrType => + containsFunctionType(tp.tp1) || containsFunctionType(tp.tp2) + case tp: FlexibleType => + containsFunctionType(tp.hi) + case _ => false + containsFunctionType(bounds.lo) || containsFunctionType(bounds.hi) + + if untpd.isFunctionWithUnknownParamType(tree) && !calleeType.exists then + // Try to instantiate `pt` when possible, by searching a nested type variable + // bounded by function types to help infer parameter types. + // If it does not work the error will be reported later in `inferredParam`, + // when we try to infer the parameter type. + tryToInstantiateDeeply(pt) val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos) diff --git a/tests/neg-custom-args/captures/i15923.check b/tests/neg-custom-args/captures/i15923.check index 3e4a97509237..60bdf2236008 100644 --- a/tests/neg-custom-args/captures/i15923.check +++ b/tests/neg-custom-args/captures/i15923.check @@ -6,7 +6,7 @@ | |Note that capability lcap cannot be included in outer capture set 's1 of parameter cap. | - |where: => refers to a fresh root capability created in anonymous function of type (using lcap: scala.caps.Capability): test2.Cap^{lcap} -> [T] => (op: test2.Cap^{lcap} => T) -> T when instantiating expected result type test2.Cap^{lcap} ->{cap²} [T] => (op: test2.Cap^'s6 ->'s7 T) ->'s8 T of function literal + |where: => refers to a fresh root capability created in anonymous function of type (using lcap: scala.caps.Capability): test2.Cap^{lcap} -> [T] => (op: test2.Cap => T) -> T when instantiating expected result type test2.Cap^{lcap} ->{cap²} [T] => (op: test2.Cap^'s6 ->'s7 T) ->'s8 T of function literal | | longer explanation available when compiling with `-explain` -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15923.scala:12:21 --------------------------------------- diff --git a/tests/pos/infer-function-type-in-union.scala b/tests/pos/infer-function-type-in-union.scala new file mode 100644 index 000000000000..f631761b3897 --- /dev/null +++ b/tests/pos/infer-function-type-in-union.scala @@ -0,0 +1,42 @@ + +def f[T](x: T): T = ??? +def f2[T](x: T | T): T = ??? +def f3[T](x: T | Null): T = ??? +def f4[T](x: Int | T): T = ??? + +trait MyOption[+T] + +object MyOption: + def apply[T](x: T | Null): MyOption[T] = ??? + +def test = + val g: AnyRef => Boolean = f { + x => x eq null // ok + } + val g2: AnyRef => Boolean = f2 { + x => x eq null // ok + } + val g3: AnyRef => Boolean = f3 { + x => x eq null // was error + } + val g4: AnyRef => Boolean = f4 { + x => x eq null // was error + } + + val o1: MyOption[String] = MyOption(null) + val o2: MyOption[String => Boolean] = MyOption { + x => x.length > 0 + } + val o3: MyOption[(String, String) => Boolean] = MyOption { + (x, y) => x.length > y.length + } + + +class Box[T] +val box: Box[Unit] = ??? +def ff1[T, U](x: T | U, y: Box[U]): T = ??? +def ff2[T, U](x: T & U): T = ??? + +def test2 = + val a1: Any => Any = ff1(x => x, box) + val a2: Any => Any = ff2(x => x) \ No newline at end of file