diff --git a/compiler/src/dotty/tools/dotc/config/JavaPlatform.scala b/compiler/src/dotty/tools/dotc/config/JavaPlatform.scala index 80cb9f556867..9e546838d6c9 100644 --- a/compiler/src/dotty/tools/dotc/config/JavaPlatform.scala +++ b/compiler/src/dotty/tools/dotc/config/JavaPlatform.scala @@ -50,7 +50,10 @@ class JavaPlatform extends Platform { cls.superClass == defn.ObjectClass && cls.directlyInheritedTraits.forall(_.is(NoInits)) && !ExplicitOuter.needsOuterIfReferenced(cls) && - cls.typeRef.fields.isEmpty // Superaccessors already show up as abstract methods here, so no test necessary + // Superaccessors already show up as abstract methods here, so no test necessary + cls.typeRef.fields.isEmpty && + // Check if the SAM can be implemented via LambdaMetaFactory + TypeErasure.samNotNeededExpansion(cls) /** We could get away with excluding BoxedBooleanClass for the * purpose of equality testing since it need not compare equal diff --git a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala index c29b971f1a5a..79fb4242a3d3 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala @@ -74,7 +74,7 @@ end SourceLanguage * only for isInstanceOf, asInstanceOf: PolyType, TypeParamRef, TypeBounds * */ -object TypeErasure { +object TypeErasure: private def erasureDependsOnArgs(sym: Symbol)(using Context) = sym == defn.ArrayClass || sym == defn.PairClass || sym.isDerivedValueClass @@ -586,7 +586,102 @@ object TypeErasure { defn.FunctionType(n = info.nonErasedParamCount) } erasure(functionType(applyInfo)) -} + + /** Check if LambdaMetaFactory can handle signature adaptation between two method types. + * + * LMF has limitations on what type adaptations it can perform automatically. + * This method checks whether manual bridging is needed for params and/or result. + * + * The adaptation rules are: + * - For parameters: primitives and value classes cannot be auto-adapted by LMF + * because the Scala spec requires null to be "unboxed" to the default value, + * but LMF throws `NullPointerException` instead. + * - For results: value classes and Unit cannot be auto-adapted by LMF. + * Non-Unit primitives can be auto-adapted since LMF only needs to box (not unbox). + * - LMF cannot auto-adapt between Object and Array types. + * + * @param implParamTypes Parameter types of the implementation method + * @param implResultType Result type of the implementation method + * @param samParamTypes Parameter types of the SAM method + * @param samResultType Result type of the SAM method + * + * @return (paramNeeded, resultNeeded) indicating what needs bridging + */ + def additionalAdaptationNeeded( + implParamTypes: List[Type], + implResultType: Type, + samParamTypes: List[Type], + samResultType: Type + )(using Context): (paramNeeded: Boolean, resultNeeded: Boolean) = + def sameClass(tp1: Type, tp2: Type) = tp1.classSymbol == tp2.classSymbol + + /** Can the implementation parameter type `tp` be auto-adapted to a different + * parameter type in the SAM? + * + * For derived value classes, we always need to do the bridging manually. + * For primitives, we cannot rely on auto-adaptation on the JVM because + * the Scala spec requires null to be "unboxed" to the default value of + * the value class, but the adaptation performed by LambdaMetaFactory + * will throw a `NullPointerException` instead. + */ + def autoAdaptedParam(tp: Type) = !tp.isErasedValueType && !tp.isPrimitiveValueType + + /** Can the implementation result type be auto-adapted to a different result + * type in the SAM? + * + * For derived value classes, it's the same story as for parameters. + * For non-Unit primitives, we can actually rely on the `LambdaMetaFactory` + * adaptation, because it only needs to box, not unbox, so no special + * handling of null is required. + */ + def autoAdaptedResult(tp: Type) = + !tp.isErasedValueType && !(tp.classSymbol eq defn.UnitClass) + + val paramAdaptationNeeded = + implParamTypes.lazyZip(samParamTypes).exists((implType, samType) => + !sameClass(implType, samType) && (!autoAdaptedParam(implType) + // LambdaMetaFactory cannot auto-adapt between Object and Array types + || samType.isInstanceOf[JavaArrayType])) + + val resultAdaptationNeeded = + !sameClass(implResultType, samResultType) && !autoAdaptedResult(implResultType) + + (paramAdaptationNeeded, resultAdaptationNeeded) + end additionalAdaptationNeeded + + /** Check if LambdaMetaFactory can handle the SAM method's required signature adaptation. + * + * When a SAM method overrides other methods, the erased signatures must be compatible + * to be qualifies as a valid functional interface on JVM. + * This method returns true if all overridden methods have compatible erased signatures + * that LMF can auto-adapt (or don't need adaptation). + * + * When this returns true, the SAM class does not need to be expanded. + * + * @param cls The SAM class to check + * @return true if LMF can handle the required adaptation + */ + def samNotNeededExpansion(cls: ClassSymbol)(using Context): Boolean = cls.typeRef.possibleSamMethods match + case Seq(samMeth) => + val samMethSym = samMeth.symbol + val erasedSamInfo = transformInfo(samMethSym, samMeth.info) + + val (erasedSamParamTypes, erasedSamResultType) = erasedSamInfo match + case mt: MethodType => (mt.paramInfos, mt.resultType) + case _ => return false + + samMethSym.allOverriddenSymbols.forall { overridden => + val erasedOverriddenInfo = transformInfo(overridden, overridden.info) + erasedOverriddenInfo match + case mt: MethodType => + val (paramNeeded, resultNeeded) = + additionalAdaptationNeeded(erasedSamParamTypes, erasedSamResultType, mt.paramInfos, mt.resultType) + !(paramNeeded || resultNeeded) + case _ => true + } + case _ => false + end samNotNeededExpansion +end TypeErasure import TypeErasure.* diff --git a/compiler/src/dotty/tools/dotc/transform/Erasure.scala b/compiler/src/dotty/tools/dotc/transform/Erasure.scala index 8403a55b6b36..99cb74ecac7f 100644 --- a/compiler/src/dotty/tools/dotc/transform/Erasure.scala +++ b/compiler/src/dotty/tools/dotc/transform/Erasure.scala @@ -453,41 +453,9 @@ object Erasure { val samParamTypes = sam.paramInfos val samResultType = sam.resultType - /** Can the implementation parameter type `tp` be auto-adapted to a different - * parameter type in the SAM? - * - * For derived value classes, we always need to do the bridging manually. - * For primitives, we cannot rely on auto-adaptation on the JVM because - * the Scala spec requires null to be "unboxed" to the default value of - * the value class, but the adaptation performed by LambdaMetaFactory - * will throw a `NullPointerException` instead. See `lambda-null.scala` - * for test cases. - * - * @see [LambdaMetaFactory](https://docs.oracle.com/en/java/javase/17/docs/api/java.base/java/lang/invoke/LambdaMetafactory.html) - */ - def autoAdaptedParam(tp: Type) = - !tp.isErasedValueType && !tp.isPrimitiveValueType - - /** Can the implementation result type be auto-adapted to a different result - * type in the SAM? - * - * For derived value classes, it's the same story as for parameters. - * For non-Unit primitives, we can actually rely on the `LambdaMetaFactory` - * adaptation, because it only needs to box, not unbox, so no special - * handling of null is required. - */ - def autoAdaptedResult = - !implResultType.isErasedValueType && !implReturnsUnit - - def sameClass(tp1: Type, tp2: Type) = tp1.classSymbol == tp2.classSymbol - - val paramAdaptationNeeded = - implParamTypes.lazyZip(samParamTypes).exists((implType, samType) => - !sameClass(implType, samType) && (!autoAdaptedParam(implType) - // LambdaMetaFactory cannot auto-adapt between Object and Array types - || samType.isInstanceOf[JavaArrayType])) - val resultAdaptationNeeded = - !sameClass(implResultType, samResultType) && !autoAdaptedResult + // Check if bridging is needed using the common function from TypeErasure + val (paramAdaptationNeeded, resultAdaptationNeeded) = + additionalAdaptationNeeded(implParamTypes, implResultType, samParamTypes, samResultType) if paramAdaptationNeeded || resultAdaptationNeeded then // Instead of instantiating `scala.FunctionN`, see if we can instantiate diff --git a/tests/run/i24573.check b/tests/run/i24573.check new file mode 100644 index 000000000000..ae6671e527f9 --- /dev/null +++ b/tests/run/i24573.check @@ -0,0 +1,43 @@ +1 +2 +3 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +31 +32 +33 +34 +41 +42 +43 +44 +45 +46 +51 +52 +53 +55 +56 +57 +61 +62 +63 +64 +71 +72 +75 +76 +81 +82 diff --git a/tests/run/i24573.scala b/tests/run/i24573.scala new file mode 100644 index 000000000000..d58ace91c98b --- /dev/null +++ b/tests/run/i24573.scala @@ -0,0 +1,178 @@ +trait ConTU[-T] extends (T => Unit): + def apply(t: T): Unit + +trait ConTI[-T] extends (T => Int): + def apply(t: T): Int + +trait ConTS[-T] extends (T => String): + def apply(t: T): String + +trait ConIR[+R] extends (Int => R): + def apply(t: Int): R + +trait ConSR[+R] extends (String => R): + def apply(t: String): R + +trait ConUR[+R] extends (() => R): + def apply(): R + +trait ConII extends (Int => Int): + def apply(t: Int): Int + +trait ConSI extends (String => Int): + def apply(t: String): Int + +trait ConIS extends (Int => String): + def apply(t: Int): String + +trait ConUU extends (() => Unit): + def apply(): Unit + +trait F1[-T, +R]: + def apply(t: T): R + +trait SFTU[-T] extends F1[T, Unit]: + def apply(t: T): Unit + +trait SFTI[-T] extends F1[T, Int]: + def apply(t: T): Int + +trait SFTS[-T] extends F1[T, String]: + def apply(t: T): String + +trait SFIR [+R] extends F1[Int, R]: + def apply(t: Int): R + +trait SFSR [+R] extends F1[String, R]: + def apply(t: String): R + +trait SFII extends F1[Int, Int]: + def apply(t: Int): Int + +trait SFSI extends F1[String, Int]: + def apply(t: String): Int + +trait SFIS extends F1[Int, String]: + def apply(t: Int): String + +trait SFIU extends F1[Int, Unit]: + def apply(t: Int): Unit + +trait F1U[-T]: + def apply(t: T): Unit + +trait SF2T[-T] extends F1U[T]: + def apply(t: T): Unit + +trait SF2I extends F1U[Int]: + def apply(t: Int): Unit + +trait SF2S extends F1U[String]: + def apply(t: String): Unit + +object Test: + def main(args: Array[String]): Unit = + val fIU: (Int => Unit) = (x: Int) => println(x) // closure by JFunction1 + fIU(1) + + val fIS: (Int => String) = (x: Int) => x.toString // closure + println(fIS(2)) + + val fUI: (() => Int) = () => 3 // closure + println(fUI()) + + val conITU: ConTU[Int] = (x: Int) => println(x) // expanded + conITU(11) + val conITI: ConTI[Int] = (x: Int) => x // closure + println(conITI(12)) + val conITS: ConTS[Int] = (x: Int) => x.toString // closure + println(conITS(13)) + val conSTS: ConTS[String] = (x: String) => x // closure + println(conSTS("14")) + + val conIRS: ConIR[String] = (x: Int) => x.toString // expanded + println(conIRS(15)) + val conIRI: ConIR[Int] = (x: Int) => x // expanded + println(conIRI(16)) + val conIRU: ConIR[Unit] = (x: Int) => println(x) // expanded + conIRU(17) + + val conSRI: ConSR[Int] = (x: String) => x.toInt // closure + println(conSRI("18")) + val conURI: ConUR[Int] = () => 19 // closure + println(conURI()) + val conURU: ConUR[Unit] = () => println("20") // closure + conURU() + + val conII: ConII = (x: Int) => x // expanded + println(conII(21)) + val conSI: ConSI = (x: String) => x.toInt // closure + println(conSI("22")) + val conIS: ConIS = (x: Int) => x.toString // expanded + println(conIS(23)) + val conUU: ConUU = () => println("24") // expanded + conUU() + + val ffIU: F1[Int, Unit] = (x: Int) => println(x) // closure + ffIU(31) + val ffIS: F1[Int, String] = (x: Int) => x.toString // closure + println(ffIS(32)) + val ffSU: F1[String, Unit] = (x: String) => println(x) // closure + ffSU("33") + val ffSI: F1[String, Int] = (x: String) => x.toInt // closure + println(ffSI("34")) + + val sfITU: SFTU[Int] = (x: Int) => println(x) // expanded + sfITU(41) + val sfSTU: SFTU[String] = (x: String) => println(x) // expanded + sfSTU("42") + + val sfITI: SFTI[Int] = (x: Int) => x // closure + println(sfITI(43)) + val sfSTI: SFTI[String] = (x: String) => x.toInt // closure + println(sfSTI("44")) + + val sfITS: SFTS[Int] = (x: Int) => x.toString // closure + println(sfITS(45)) + val sfSTS: SFTS[String] = (x: String) => x // closure + println(sfSTS("46")) + + val sfIRI: SFIR[Int] = (x: Int) => x // expanded + println(sfIRI(51)) + val sfIRS: SFIR[String] = (x: Int) => x.toString // expanded + println(sfIRS(52)) + val sfIRU: SFIR[Unit] = (x: Int) => println(x) // expanded + sfIRU(53) + + val sfSRI: SFSR[Int] = (x: String) => x.toInt // closure + println(sfSRI("55")) + val sfSRS: SFSR[String] = (x: String) => x // closure + println(sfSRS("56")) + val sfSRU: SFSR[Unit] = (x: String) => println(x) // closure + sfSRU("57") + + val sfII: SFII = (x: Int) => x // expanded + println(sfII(61)) + val sfSI: SFSI = (x: String) => x.toInt // closure + println(sfSI("62")) + val sfIS: SFIS = (x: Int) => x.toString // expanded + println(sfIS(63)) + val sfIU: SFIU = (x: Int) => println(x) // expanded + sfIU(64) + + val f2ITU: F1U[Int] = (x: Int) => println(x) // closure + f2ITU(71) + val f2STU: F1U[String] = (x: String) => println(x) // closure + f2STU("72") + + val sf2IT: SF2T[Int] = (x: Int) => println(x) // closure + sf2IT(75) + val sf2ST: SF2T[String] = (x: String) => println(x) // closure + sf2ST("76") + + val sf2I: SF2I = (x: Int) => println(x) // expanded + sf2I(81) + val sf2S: SF2S = (x: String) => println(x) // closure + sf2S("82") + +end Test