Skip to content

Commit

Permalink
Intrinsify constValueTuple and summonAll
Browse files Browse the repository at this point in the history
The new implementation instantiates the TupleN/TupleXXL classes directly.
This avoids the expensive construction of tuples using `*:`.

Fixes #15988
  • Loading branch information
nicolasstucki committed Jun 20, 2023
1 parent b614d84 commit 3e393a9
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 29 deletions.
19 changes: 19 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1514,6 +1514,25 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
}

/** Creates the tuple containing the elemets */
def tupleTree(elems: List[Tree])(using Context): Tree = {
val arity = elems.length
if arity == 0 then
ref(defn.EmptyTupleModule)
else if arity <= Definitions.MaxTupleArity then
// TupleN[elem1Tpe, ...](elem1, ...)
ref(defn.TupleType(arity).nn.typeSymbol.companionModule)
.select(nme.apply)
.appliedToTypes(elems.map(_.tpe.widenIfUnstable))
.appliedToArgs(elems)
else
// TupleXXL.apply(elems*) // TODO add and use Tuple.apply(elems*) ?
ref(defn.TupleXXLModule)
.select(nme.apply)
.appliedToVarargs(elems.map(_.asInstance(defn.ObjectType)), TypeTree(defn.ObjectType))
.asInstance(defn.tupleType(elems.map(elem => elem.tpe.widenIfUnstable)))
}

/** Creates the tuple type tree representation of the type trees in `ts` */
def tupleTypeTree(elems: List[Tree])(using Context): Tree = {
val arity = elems.length
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,10 @@ class Definitions {
@tu lazy val Compiletime_requireConst : Symbol = CompiletimePackageClass.requiredMethod("requireConst")
@tu lazy val Compiletime_constValue : Symbol = CompiletimePackageClass.requiredMethod("constValue")
@tu lazy val Compiletime_constValueOpt: Symbol = CompiletimePackageClass.requiredMethod("constValueOpt")
@tu lazy val Compiletime_constValueTuple: Symbol = CompiletimePackageClass.requiredMethod("constValueTuple")
@tu lazy val Compiletime_summonFrom : Symbol = CompiletimePackageClass.requiredMethod("summonFrom")
@tu lazy val Compiletime_summonInline : Symbol = CompiletimePackageClass.requiredMethod("summonInline")
@tu lazy val Compiletime_summonInline : Symbol = CompiletimePackageClass.requiredMethod("summonInline")
@tu lazy val Compiletime_summonAll : Symbol = CompiletimePackageClass.requiredMethod("summonAll")
@tu lazy val CompiletimeTestingPackage: Symbol = requiredPackage("scala.compiletime.testing")
@tu lazy val CompiletimeTesting_typeChecks: Symbol = CompiletimeTestingPackage.requiredMethod("typeChecks")
@tu lazy val CompiletimeTesting_typeCheckErrors: Symbol = CompiletimeTestingPackage.requiredMethod("typeCheckErrors")
Expand Down Expand Up @@ -932,6 +934,8 @@ class Definitions {
@tu lazy val TupleTypeRef: TypeRef = requiredClassRef("scala.Tuple")
def TupleClass(using Context): ClassSymbol = TupleTypeRef.symbol.asClass
@tu lazy val Tuple_cons: Symbol = TupleClass.requiredMethod("*:")
@tu lazy val TupleModule: Symbol = requiredModule("scala.Tuple")
@tu lazy val Tuple_fromArray: Symbol = TupleModule.requiredMethod("fromArray")
@tu lazy val EmptyTupleModule: Symbol = requiredModule("scala.EmptyTuple")
@tu lazy val NonEmptyTupleTypeRef: TypeRef = requiredClassRef("scala.NonEmptyTuple")
def NonEmptyTupleClass(using Context): ClassSymbol = NonEmptyTupleTypeRef.symbol.asClass
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/inlines/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,8 @@ class Inliner(val call: tpd.Tree)(using Context):
// assertAllPositioned(tree) // debug
tree.changeOwner(originalOwner, ctx.owner)

def tryConstValue: Tree =
TypeComparer.constValue(callTypeArgs.head.tpe) match {
def tryConstValue(tpe: Type): Tree =
TypeComparer.constValue(tpe) match {
case Some(c) => Literal(c).withSpan(call.span)
case _ => EmptyTree
}
Expand Down
61 changes: 42 additions & 19 deletions compiler/src/dotty/tools/dotc/inlines/Inlines.scala
Original file line number Diff line number Diff line change
Expand Up @@ -408,36 +408,59 @@ object Inlines:
return Intrinsics.codeOf(arg, call.srcPos)
case _ =>

// Special handling of `constValue[T]`, `constValueOpt[T], and summonInline[T]`
// Special handling of `constValue[T]`, `constValueOpt[T]`, `constValueTuple[T]`, and `summonInline[T]`
if callTypeArgs.length == 1 then
if (inlinedMethod == defn.Compiletime_constValue) {
val constVal = tryConstValue

def constValueOrError(tpe: Type): Tree =
val constVal = tryConstValue(tpe)
if constVal.isEmpty then
val msg = NotConstant("cannot take constValue", callTypeArgs.head.tpe)
return ref(defn.Predef_undefined).withSpan(call.span).withType(ErrorType(msg))
val msg = NotConstant("cannot take constValue", tpe)
ref(defn.Predef_undefined).withSpan(callTypeArgs.head.span).withType(ErrorType(msg))
else
return constVal
constVal

def searchImplicitOrError(tpe: Type): Tree =
val evTyper = new Typer(ctx.nestingLevel + 1)
val evCtx = ctx.fresh.setTyper(evTyper)
inContext(evCtx) {
val evidence = evTyper.inferImplicitArg(tpe, callTypeArgs.head.span)
evidence.tpe match
case fail: Implicits.SearchFailureType =>
errorTree(call, evTyper.missingArgMsg(evidence, tpe, ""))
case _ =>
evidence
}

def unrollTupleTypes(tpe: Type): List[Type] = tpe match
case AppliedType(tycon, args) if defn.isTupleClass(tycon.typeSymbol) =>
args
case AppliedType(tycon, head :: tail :: Nil) if tycon.isRef(defn.PairClass) =>
head :: unrollTupleTypes(tail)
case tpe: TermRef if tpe.symbol == defn.EmptyTupleModule =>
Nil

if (inlinedMethod == defn.Compiletime_constValue) {
return constValueOrError(callTypeArgs.head.tpe)
}
else if (inlinedMethod == defn.Compiletime_constValueOpt) {
val constVal = tryConstValue
val constVal = tryConstValue(callTypeArgs.head.tpe)
return (
if (constVal.isEmpty) ref(defn.NoneModule.termRef)
else New(defn.SomeClass.typeRef.appliedTo(constVal.tpe), constVal :: Nil)
)
}
else if (inlinedMethod == defn.Compiletime_constValueTuple) {
val types = unrollTupleTypes(callTypeArgs.head.tpe.dealias)
val constants = types.map(constValueOrError)
return Typed(tpd.tupleTree(constants), TypeTree(callTypeArgs.head.tpe)).withSpan(call.span)
}
else if (inlinedMethod == defn.Compiletime_summonInline) {
def searchImplicit(tpt: Tree) =
val evTyper = new Typer(ctx.nestingLevel + 1)
val evCtx = ctx.fresh.setTyper(evTyper)
inContext(evCtx) {
val evidence = evTyper.inferImplicitArg(tpt.tpe, tpt.span)
evidence.tpe match
case fail: Implicits.SearchFailureType =>
errorTree(call, evTyper.missingArgMsg(evidence, tpt.tpe, ""))
case _ =>
evidence
}
return searchImplicit(callTypeArgs.head)
return searchImplicitOrError(callTypeArgs.head.tpe)
}
else if (inlinedMethod == defn.Compiletime_summonAll) {
val types = unrollTupleTypes(callTypeArgs.head.tpe.dealias)
val implicits = types.map(searchImplicitOrError)
return Typed(tpd.tupleTree(implicits), TypeTree(callTypeArgs.head.tpe)).withSpan(call.span)
}
end if

Expand Down
10 changes: 3 additions & 7 deletions library/src/scala/compiletime/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,9 @@ transparent inline def constValue[T]: T =
* `(constValue[X1], ..., constValue[Xn])`.
*/
inline def constValueTuple[T <: Tuple]: T =
val res =
inline erasedValue[T] match
case _: EmptyTuple => EmptyTuple
case _: (t *: ts) => constValue[t] *: constValueTuple[ts]
end match
res.asInstanceOf[T]
end constValueTuple
// implemented in dotty.tools.dotc.typer.Inliner
error("Compiler bug: `constValueTuple` was not evaluated by the compiler")


/** Summons first given matching one of the listed cases. E.g. in
*
Expand Down
6 changes: 6 additions & 0 deletions tests/run/i15988a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import scala.compiletime.constValueTuple

@main def Test: Unit =
assert(constValueTuple[EmptyTuple] == EmptyTuple)
assert(constValueTuple[("foo", 5, 3.14, "bar", false)] == ("foo", 5, 3.14, "bar", false))
assert(constValueTuple[(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)] == (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23))

0 comments on commit 3e393a9

Please sign in to comment.