From 0182e069c0d2a941e3061c8834c1a72892dbbf7f Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Fri, 15 Oct 2021 12:01:25 +0200 Subject: [PATCH] Fix improper usage of `constrained` breaking type inference In multiple places, we had code equivalent to the following pattern: val (tl2, targs) = constrained(tl) tl2.resultType <:< ... which lead to subtype checks directly involving the TypeParamRefs of the constrained type lambda. This commit uses the following pattern instead: val (tl2, targs) = constrained(tl) tl2.instantiate(targs.map(_.tpe)) <:< ... which substitutes the TypeParamRefs by the corresponding TypeVars in the constraint. This is necessary because when comparing TypeParamRefs in isSubType: - we only recurse on the bounds of the TypeParamRef using `isSubTypeWhenFrozen` which prevents further constraints from being added (see the added stm.scala test case for an example where this matters). - if the corresponding TypeVar is instantiated and the TyperState has been gc()'ed, there is no way to find the instantiation corresponding to the current TypeParamRef anymore. There is one place where I left the old logic intact: `TrackingTypeComparer#matchCase` because the match type caching logic (in `MatchType#reduced`) conflicted with the use of TypeVars since it retracts the current TyperState. This change breaks a test which involves an unlikely combination of implicit conversion, overloading and apply insertion. Given that there is always a tension between type inference and implicit conversion, and that we're discouraging uses of implicit conversions, I think that's an acceptable trade-off. --- compiler/src/dotty/tools/dotc/typer/Applications.scala | 6 +++--- compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala | 9 +++++++-- .../test/dotty/tools/dotc/core/ConstraintsTest.scala | 9 +++++---- tests/pos/stm.scala | 10 ++++++++++ tests/pos/t0851.scala | 3 +-- tests/pos/t2913.scala | 3 +-- 6 files changed, 27 insertions(+), 13 deletions(-) create mode 100644 tests/pos/stm.scala diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index ab7187f6f4d4..b59eb45df771 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -411,7 +411,7 @@ trait Applications extends Compatibility { */ @threadUnsafe lazy val methType: Type = liftedFunType.widen match { case funType: MethodType => funType - case funType: PolyType => constrained(funType).resultType + case funType: PolyType => instantiateWithTypeVars(funType) case tp => tp //was: funType } @@ -1571,7 +1571,7 @@ trait Applications extends Compatibility { case tp2: MethodType => true // (3a) case tp2: PolyType if tp2.resultType.isInstanceOf[MethodType] => true // (3a) case tp2: PolyType => // (3b) - explore(isAsSpecificValueType(tp1, constrained(tp2).resultType)) + explore(isAsSpecificValueType(tp1, instantiateWithTypeVars(tp2))) case _ => // 3b) isAsSpecificValueType(tp1, tp2) } @@ -1738,7 +1738,7 @@ trait Applications extends Compatibility { resultType.revealIgnored match { case resultType: ValueType => altType.widen match { - case tp: PolyType => resultConforms(altSym, constrained(tp).resultType, resultType) + case tp: PolyType => resultConforms(altSym, instantiateWithTypeVars(tp), resultType) case tp: MethodType => constrainResult(altSym, tp.resultType, resultType) case _ => true } diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index 167128be08e4..fcedd8a0cd56 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -659,6 +659,11 @@ object ProtoTypes { def constrained(tl: TypeLambda)(using Context): TypeLambda = constrained(tl, EmptyTree)._1 + /** Instantiate `tl` with fresh type variables added to the constraint. */ + def instantiateWithTypeVars(tl: TypeLambda)(using Context): Type = + val targs = constrained(tl, ast.tpd.EmptyTree, alwaysAddTypeVars = true)._2 + tl.instantiate(targs.tpes) + /** A new type variable with given bounds for its origin. * @param represents If exists, the TermParamRef that the TypeVar represents * in the substitution generated by `resultTypeApprox` @@ -707,7 +712,7 @@ object ProtoTypes { else mt.resultType /** The normalized form of a type - * - unwraps polymorphic types, tracking their parameters in the current constraint + * - instantiate polymorphic types with fresh type variables in the current constraint * - skips implicit parameters of methods and functions; * if result type depends on implicit parameter, replace with wildcard. * - converts non-dependent method types to the corresponding function types @@ -726,7 +731,7 @@ object ProtoTypes { Stats.record("normalize") tp.widenSingleton match { case poly: PolyType => - normalize(constrained(poly).resultType, pt) + normalize(instantiateWithTypeVars(poly), pt) case mt: MethodType => if (mt.isImplicitMethod) normalize(resultTypeApprox(mt, wildcardOnly = true), pt) else if (mt.isResultDependent) tp diff --git a/compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala b/compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala index 37ccf5667a7d..5ab162b9f05c 100644 --- a/compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala +++ b/compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala @@ -7,6 +7,7 @@ import dotty.tools.dotc.core.Contexts.{*, given} import dotty.tools.dotc.core.Decorators.{*, given} import dotty.tools.dotc.core.Symbols.* import dotty.tools.dotc.core.Types.* +import dotty.tools.dotc.ast.tpd.* import dotty.tools.dotc.typer.ProtoTypes.constrained import org.junit.Test @@ -18,8 +19,8 @@ class ConstraintsTest: @Test def mergeParamsTransitivity: Unit = inCompilerContext(TestConfiguration.basicClasspath, scalaSources = "trait A { def foo[S, T, R]: Any }") { - val tp = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda]) - val List(s, t, r) = tp.paramRefs + val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2 + val List(s, t, r) = tvars.tpes val innerCtx = ctx.fresh.setExploreTyperState() inContext(innerCtx) { @@ -37,8 +38,8 @@ class ConstraintsTest: @Test def mergeBoundsTransitivity: Unit = inCompilerContext(TestConfiguration.basicClasspath, scalaSources = "trait A { def foo[S, T]: Any }") { - val tp = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda]) - val List(s, t) = tp.paramRefs + val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2 + val List(s, t) = tvars.tpes val innerCtx = ctx.fresh.setExploreTyperState() inContext(innerCtx) { diff --git a/tests/pos/stm.scala b/tests/pos/stm.scala new file mode 100644 index 000000000000..48ff946f9b5c --- /dev/null +++ b/tests/pos/stm.scala @@ -0,0 +1,10 @@ +class Inv[X] +class Ref[X] +object Ref { + def apply(i: Inv[Int], x: Int): Ref[Int] = ??? + def apply[Y](i: Inv[Y], x: Y): Ref[Y] = ??? +} + +class A { + val ref: Ref[List[AnyRef]] = Ref(new Inv[List[AnyRef]], List.empty) +} diff --git a/tests/pos/t0851.scala b/tests/pos/t0851.scala index fdc504af75c5..c7393723b148 100644 --- a/tests/pos/t0851.scala +++ b/tests/pos/t0851.scala @@ -1,9 +1,8 @@ package test object test1 { - case class Foo[T,T2](f : (T,T2) => String) extends (((T,T2)) => String){ + case class Foo[T,T2](f : (T,T2) => String) { def apply(t : T) = (s:T2) => f(t,s) - def apply(p : (T,T2)) = f(p._1,p._2) } implicit def g[T](f : (T,String) => String): Foo[T, String] = Foo(f) def main(args : Array[String]) : Unit = { diff --git a/tests/pos/t2913.scala b/tests/pos/t2913.scala index f91ed7b51318..9d7b898cbe9d 100644 --- a/tests/pos/t2913.scala +++ b/tests/pos/t2913.scala @@ -33,9 +33,8 @@ object TestNoAutoTupling { // t0851 is essentially the same: object test1 { - case class Foo[T,T2](f : (T,T2) => String) extends (((T,T2)) => String){ + case class Foo[T,T2](f : (T,T2) => String) { def apply(t : T) = (s:T2) => f(t,s) - def apply(p : (T,T2)) = f(p._1,p._2) } implicit def g[T](f : (T,String) => String): test1.Foo[T,String] = Foo(f) def main(args : Array[String]) : Unit = {