diff --git a/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala b/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala index 7043169e2ae3..f5b44fc66911 100644 --- a/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala +++ b/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala @@ -16,7 +16,7 @@ import collection.mutable /** A utility class offering methods for rewriting inlined code */ class InlineReducer(inliner: Inliner)(using Context): import tpd.* - import Inliner.{isElideableExpr, DefBuffer} + import Inliner.{isElideableExpr, DefBuffer, inlinedConstToLiteral} import inliner.{call, newSym, tryInlineArg, paramBindingDef} extension (tp: Type) @@ -201,7 +201,7 @@ class InlineReducer(inliner: Inliner)(using Context): val copied = sym.copy(info = rhs.tpe.widenInlineScrutinee, coord = sym.coord, flags = sym.flags &~ Case).asTerm adjustErased(copied, rhs) - caseBindingMap += ((sym, ValDef(copied, constToLiteral(rhs)).withSpan(sym.span))) + caseBindingMap += ((sym, ValDef(copied, inlinedConstToLiteral(rhs)).withSpan(sym.span))) def newTypeBinding(sym: TypeSymbol, alias: Type): Unit = { val copied = sym.copy(info = TypeAlias(alias), coord = sym.coord).asType @@ -321,7 +321,7 @@ class InlineReducer(inliner: Inliner)(using Context): case (pat :: pats1, selector :: selectors1) => val elem = newSym(InlineBinderName.fresh(), Synthetic, selector.tpe.widenInlineScrutinee).asTerm adjustErased(elem, selector) - val rhs = constToLiteral(selector) + val rhs = inlinedConstToLiteral(selector) elem.defTree = rhs caseBindingMap += ((NoSymbol, ValDef(elem, rhs).withSpan(elem.span))) reducePattern(caseBindingMap, elem.termRef, pat) && @@ -337,7 +337,7 @@ class InlineReducer(inliner: Inliner)(using Context): else paramCls.asClass.paramAccessors val selectors = for (accessor <- caseAccessors) - yield constToLiteral(reduceProjection(ref(scrut).select(accessor).ensureApplied)) + yield inlinedConstToLiteral(reduceProjection(ref(scrut).select(accessor).ensureApplied)) caseAccessors.length == pats.length && reduceSubPatterns(pats, selectors) } else false diff --git a/compiler/src/dotty/tools/dotc/inlines/Inliner.scala b/compiler/src/dotty/tools/dotc/inlines/Inliner.scala index 356e5ad40fdd..e3f5c76c7783 100644 --- a/compiler/src/dotty/tools/dotc/inlines/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/inlines/Inliner.scala @@ -183,6 +183,27 @@ object Inliner: end OpaqueProxy + /** A more powerful version of [[constToLiteral]] that also can "see through" + * [[Block]], [[Inlined]] and [[Typed]] trees that are elidable (see + * [[isElideableExpr]]). + */ + def inlinedConstToLiteral(rootTree: Tree)(using Context): Tree = + def rec(tree: Tree): Tree = + inline def recChild(subTree: Tree): Tree = + val res = rec(subTree) + if res eq subTree then tree else res + + tree match + case Typed(expr, _) => recChild(expr) + case Inlined(_, _, expr) => recChild(expr) + case Block(_, expr) => recChild(expr) + case _ => constToLiteral(tree) + + if isElideableExpr(rootTree) then + rec(rootTree) + else + constToLiteral(rootTree) + private[inlines] def newSym(name: Name, flags: FlagSet, info: Type, span: Span)(using Context): Symbol = newSymbol(ctx.owner, name, flags, info, coord = span) end Inliner @@ -897,7 +918,7 @@ class Inliner(val call: tpd.Tree)(using Context): //if the projection leads to a typed tree then we stop reduction resNoReduce else - val res = constToLiteral(reducedProjection) + val res = inlinedConstToLiteral(reducedProjection) if resNoReduce ne res then typed(res, pt) // redo typecheck if reduction changed something else if res.symbol.isInlineMethod then @@ -928,7 +949,7 @@ class Inliner(val call: tpd.Tree)(using Context): override def typedValDef(vdef: untpd.ValDef, sym: Symbol)(using Context): Tree = val vdef1 = if sym.is(Inline) then - val rhs = typed(vdef.rhs) + val rhs = inlinedConstToLiteral(typed(vdef.rhs)) sym.info = rhs.tpe untpd.cpy.ValDef(vdef)(vdef.name, untpd.TypeTree(rhs.tpe), untpd.TypedSplice(rhs)) else vdef @@ -936,11 +957,11 @@ class Inliner(val call: tpd.Tree)(using Context): override def typedApply(tree: untpd.Apply, pt: Type)(using Context): Tree = val locked = ctx.typerState.ownedVars - specializeEq(inlineIfNeeded(constToLiteral(BetaReduce(super.typedApply(tree, pt))), pt, locked)) + specializeEq(inlineIfNeeded(inlinedConstToLiteral(BetaReduce(super.typedApply(tree, pt))), pt, locked)) override def typedTypeApply(tree: untpd.TypeApply, pt: Type)(using Context): Tree = val locked = ctx.typerState.ownedVars - val tree1 = inlineIfNeeded(constToLiteral(BetaReduce(super.typedTypeApply(tree, pt))), pt, locked) + val tree1 = inlineIfNeeded(inlinedConstToLiteral(BetaReduce(super.typedTypeApply(tree, pt))), pt, locked) if tree1.symbol == defn.QuotedTypeModule_of then ctx.compilationUnit.needsStaging = true tree1 @@ -1021,8 +1042,8 @@ class Inliner(val call: tpd.Tree)(using Context): case _ => rhs0 } val rhs2 = rhs1 match { - case Typed(expr, tpt) if rhs1.span.isSynthetic => constToLiteral(expr) - case _ => constToLiteral(rhs1) + case Typed(expr, tpt) if rhs1.span.isSynthetic => inlinedConstToLiteral(expr) + case _ => inlinedConstToLiteral(rhs1) } val (usedBindings, rhs3) = dropUnusedDefs(caseBindings, rhs2) val rhs = seq(usedBindings, rhs3) @@ -1056,7 +1077,7 @@ class Inliner(val call: tpd.Tree)(using Context): val meth = tree.symbol if meth.isAllOf(DeferredInline) then errorTree(tree, em"Deferred inline ${meth.showLocated} cannot be invoked") - else if Inlines.needsInlining(tree) then Inlines.inlineCall(simplify(tree, pt, locked)) + else if Inlines.needsInlining(tree) then inlinedConstToLiteral(Inlines.inlineCall(simplify(tree, pt, locked))) else tree override def typedUnadapted(tree: untpd.Tree, pt: Type, locked: TypeVars)(using Context): Tree = diff --git a/docs/_docs/reference/metaprogramming/inline.md b/docs/_docs/reference/metaprogramming/inline.md index f4988d02e0ba..d913e890280a 100644 --- a/docs/_docs/reference/metaprogramming/inline.md +++ b/docs/_docs/reference/metaprogramming/inline.md @@ -247,8 +247,19 @@ trait InlineConstants: inline val myShort: Short object Constants extends InlineConstants: - inline val myShort/*: Short(4)*/ = 4 + inline val myShort/*: (4 : Short)*/ = 4 ``` + + +Inline values that are inside inline methods are only required to be constant _after inlining_. Therefore, the following is valid: + +```scala +inline def double(inline x: Int): Int = x * 2 +inline def eight: Int = + inline val res = double(4) + res +``` + ## Transparent Inline Methods diff --git a/tests/neg/i18123b.check b/tests/neg/i18123b.check new file mode 100644 index 000000000000..eb0fc99ea362 --- /dev/null +++ b/tests/neg/i18123b.check @@ -0,0 +1,9 @@ +-- [E007] Type Mismatch Error: tests/neg/i18123b.scala:8:8 ------------------------------------------------------------- +8 |def z = y.rep().toUpperCase // error + | ^^^^^^^ + | Found: (??? : => Nothing) + | Required: ?{ toUpperCase: ? } + | Note that implicit conversions were not tried because the result of an implicit conversion + | must be more specific than ?{ toUpperCase: } + | + | longer explanation available when compiling with `-explain` diff --git a/tests/neg/i18123b.scala b/tests/neg/i18123b.scala new file mode 100644 index 000000000000..03c31b8166f2 --- /dev/null +++ b/tests/neg/i18123b.scala @@ -0,0 +1,8 @@ +// Minimized version of `tests/pos/i18123.scala` to test #24425. + +extension (x: String) + transparent inline def rep(min: Int = 0): String = ??? + +def y: String = ??? + +def z = y.rep().toUpperCase // error diff --git a/tests/pos/i24412.scala b/tests/pos/i24412.scala new file mode 100644 index 000000000000..3b9ec8d5a579 --- /dev/null +++ b/tests/pos/i24412.scala @@ -0,0 +1,13 @@ +object test { + import scala.compiletime.erasedValue + + inline def contains[T <: Tuple, E]: Boolean = inline erasedValue[T] match { + case _: EmptyTuple => false + case _: (_ *: tail) => contains[tail, E] + } + inline def check[T <: Tuple]: Unit = { + inline if contains[T, Long] && false then ??? + } + + check[(String, Double)] +} diff --git a/tests/pos/i24421.scala b/tests/pos/i24421.scala new file mode 100644 index 000000000000..07b886afcdcb --- /dev/null +++ b/tests/pos/i24421.scala @@ -0,0 +1,22 @@ +import scala.language.experimental.erasedDefinitions +import scala.compiletime.{erasedValue, summonInline, error} + +inline def sizeTuple[T <: Tuple](): Long = + inline erasedValue[T] match + case _: EmptyTuple => 0 + case _: (h *: t) => size[h] + sizeTuple[t]() + +inline def sizeProduct[T](m: scala.deriving.Mirror.ProductOf[T]): Long = + sizeTuple[m.MirroredElemTypes]() + +inline def size[T]: Long = + inline erasedValue[T] match + case _: Char => 2 + case _: Int => 4 + case _: Long => 8 + case _: Double => 8 + case _: Product => sizeProduct(summonInline[scala.deriving.Mirror.ProductOf[T]]) + case _ => error(s"unsupported type") + +@main def Test = + assert(size[(Int, Long)] == 12) diff --git a/tests/pos/inline-val-in-inline-method.scala b/tests/pos/inline-val-in-inline-method.scala new file mode 100644 index 000000000000..e2c13be6814e --- /dev/null +++ b/tests/pos/inline-val-in-inline-method.scala @@ -0,0 +1,6 @@ +// Example in docs/_docs/reference/metaprogramming/inline.md + +inline def double(inline x: Int): Int = x * 2 +inline def eight: Int = + inline val res = double(4) + res diff --git a/tests/pos/inline-val-short.scala b/tests/pos/inline-val-short.scala new file mode 100644 index 000000000000..0bd37c83b13b --- /dev/null +++ b/tests/pos/inline-val-short.scala @@ -0,0 +1,7 @@ +// Example in docs/_docs/reference/metaprogramming/inline.md + +trait InlineConstants: + inline val myShort: Short + +object Constants extends InlineConstants: + inline val myShort/*: Short(4)*/ = 4 diff --git a/tests/run/i24420-inline-local-ref.scala b/tests/run/i24420-inline-local-ref.scala new file mode 100644 index 000000000000..84828de9f443 --- /dev/null +++ b/tests/run/i24420-inline-local-ref.scala @@ -0,0 +1,12 @@ +inline def f(): Long = + 1L + +inline def g(): Long = + inline val x = f() + x + +inline def h(): Long = + inline if g() > 0L then 1L else 0L + +@main def Test: Unit = + assert(h() == 1L) diff --git a/tests/run/i24420-inline-val.scala b/tests/run/i24420-inline-val.scala new file mode 100644 index 000000000000..ec201d948a96 --- /dev/null +++ b/tests/run/i24420-inline-val.scala @@ -0,0 +1,25 @@ +inline def f1(): Long = + 1L + +inline def f2(): Long = + inline val x = f1() + 1L + x + +inline def f3(): Long = + inline val x = f1() + x + +inline def g1(): Boolean = + true + +inline def g2(): Long = + inline if g1() then 1L else 2L + +inline def g3(): Long = + inline if f1() > 0L then 1L else 2L + +@main def Test: Unit = + assert(f2() == 2L) + assert(f3() == 1L) + assert(g2() == 1L) + assert(g3() == 1L) diff --git a/tests/run/i24420-transparent-inline-local-ref.scala b/tests/run/i24420-transparent-inline-local-ref.scala new file mode 100644 index 000000000000..7d8180680130 --- /dev/null +++ b/tests/run/i24420-transparent-inline-local-ref.scala @@ -0,0 +1,12 @@ +transparent inline def f(): Long = + 1L + +transparent inline def g(): Long = + inline val x = f() + x + +transparent inline def h(): Long = + inline if g() > 0L then 1L else 0L + +@main def Test: Unit = + assert(h() == 1L)