Skip to content

Commit

Permalink
Check flags for newMethod, newVal and newBind (scala#16565)
Browse files Browse the repository at this point in the history
  • Loading branch information
smarter committed May 3, 2023
2 parents ebaceb8 + 3c3c0fb commit b8d2966
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 15 deletions.
46 changes: 34 additions & 12 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import dotty.tools.dotc.ast.untpd
import dotty.tools.dotc.core.Annotations
import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.Flags._
import dotty.tools.dotc.core.NameKinds
import dotty.tools.dotc.core.NameOps._
import dotty.tools.dotc.core.StdNames._
Expand Down Expand Up @@ -276,12 +275,13 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object DefDef extends DefDefModule:
def apply(symbol: Symbol, rhsFn: List[List[Tree]] => Option[Term]): DefDef =
assert(symbol.isTerm, s"expected a term symbol but received $symbol")
xCheckMacroAssert(symbol.isTerm, s"expected a term symbol but received $symbol")
xCheckMacroAssert(symbol.flags.is(Flags.Method), "expected a symbol with `Method` flag set")
withDefaultPos(tpd.DefDef(symbol.asTerm, prefss =>
xCheckMacroedOwners(xCheckMacroValidExpr(rhsFn(prefss)), symbol).getOrElse(tpd.EmptyTree)
xCheckedMacroOwners(xCheckMacroValidExpr(rhsFn(prefss)), symbol).getOrElse(tpd.EmptyTree)
))
def copy(original: Tree)(name: String, paramss: List[ParamClause], tpt: TypeTree, rhs: Option[Term]): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, paramss, tpt, xCheckMacroedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
tpd.cpy.DefDef(original)(name.toTermName, paramss, tpt, xCheckedMacroOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
def unapply(ddef: DefDef): (String, List[ParamClause], TypeTree, Option[Term]) =
(ddef.name.toString, ddef.paramss, ddef.tpt, optional(ddef.rhs))
end DefDef
Expand All @@ -307,9 +307,10 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object ValDef extends ValDefModule:
def apply(symbol: Symbol, rhs: Option[Term]): ValDef =
withDefaultPos(tpd.ValDef(symbol.asTerm, xCheckMacroedOwners(xCheckMacroValidExpr(rhs), symbol).getOrElse(tpd.EmptyTree)))
xCheckMacroAssert(!symbol.flags.is(Flags.Method), "expected a symbol without `Method` flag set")
withDefaultPos(tpd.ValDef(symbol.asTerm, xCheckedMacroOwners(xCheckMacroValidExpr(rhs), symbol).getOrElse(tpd.EmptyTree)))
def copy(original: Tree)(name: String, tpt: TypeTree, rhs: Option[Term]): ValDef =
tpd.cpy.ValDef(original)(name.toTermName, tpt, xCheckMacroedOwners(xCheckMacroValidExpr(rhs), original.symbol).getOrElse(tpd.EmptyTree))
tpd.cpy.ValDef(original)(name.toTermName, tpt, xCheckedMacroOwners(xCheckMacroValidExpr(rhs), original.symbol).getOrElse(tpd.EmptyTree))
def unapply(vdef: ValDef): (String, TypeTree, Option[Term]) =
(vdef.name.toString, vdef.tpt, optional(vdef.rhs))

Expand Down Expand Up @@ -398,7 +399,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
def etaExpand(owner: Symbol): Term = self.tpe.widen match {
case mtpe: Types.MethodType if !mtpe.isParamDependent =>
val closureResType = mtpe.resType match {
case t: Types.MethodType => t.toFunctionType(isJava = self.symbol.is(JavaDefined))
case t: Types.MethodType => t.toFunctionType(isJava = self.symbol.is(dotc.core.Flags.JavaDefined))
case t => t
}
val closureTpe = Types.MethodType(mtpe.paramNames, mtpe.paramInfos, closureResType)
Expand Down Expand Up @@ -828,7 +829,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
object Lambda extends LambdaModule:
def apply(owner: Symbol, tpe: MethodType, rhsFn: (Symbol, List[Tree]) => Tree): Block =
val meth = dotc.core.Symbols.newAnonFun(owner, tpe)
withDefaultPos(tpd.Closure(meth, tss => xCheckMacroedOwners(xCheckMacroValidExpr(rhsFn(meth, tss.head.map(withDefaultPos))), meth)))
withDefaultPos(tpd.Closure(meth, tss => xCheckedMacroOwners(xCheckMacroValidExpr(rhsFn(meth, tss.head.map(withDefaultPos))), meth)))

def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
case Block((ddef @ DefDef(_, tpd.ValDefs(params) :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
Expand Down Expand Up @@ -1499,6 +1500,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object Bind extends BindModule:
def apply(sym: Symbol, pattern: Tree): Bind =
xCheckMacroAssert(sym.flags.is(Flags.Case), "expected a symbol with `Case` flag set")
withDefaultPos(tpd.Bind(sym, pattern))
def copy(original: Tree)(name: String, pattern: Tree): Bind =
withDefaultPos(tpd.cpy.Bind(original)(name.toTermName, pattern))
Expand Down Expand Up @@ -2539,14 +2541,23 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
newMethod(owner, name, tpe, Flags.EmptyFlags, noSymbol)
def newMethod(owner: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
assert(!privateWithin.exists || privateWithin.isType, "privateWithin must be a type symbol or `Symbol.noSymbol`")
checkValidFlags(flags.toTermFlags, Flags.validMethodFlags)
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags | dotc.core.Flags.Method, tpe, privateWithin)
def newVal(owner: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
assert(!privateWithin.exists || privateWithin.isType, "privateWithin must be a type symbol or `Symbol.noSymbol`")
checkValidFlags(flags.toTermFlags, Flags.validValFlags)
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags, tpe, privateWithin)
def newBind(owner: Symbol, name: String, flags: Flags, tpe: TypeRepr): Symbol =
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags | Case, tpe)
checkValidFlags(flags.toTermFlags, Flags.validBindFlags)
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags | dotc.core.Flags.Case, tpe)
def noSymbol: Symbol = dotc.core.Symbols.NoSymbol

private inline def checkValidFlags(inline flags: Flags, inline valid: Flags): Unit =
xCheckMacroAssert(
flags <= valid,
s"Received invalid flags. Expected flags ${flags.show} to only contain a subset of ${valid.show}."
)

def freshName(prefix: String): String =
NameKinds.MacroNames.fresh(prefix.toTermName).toString
end Symbol
Expand Down Expand Up @@ -2619,7 +2630,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
self.isTerm && !self.is(dotc.core.Flags.Method) && !self.is(dotc.core.Flags.Case/*, FIXME add this check and fix sourcecode butNot = Enum | Module*/)
def isDefDef: Boolean = self.is(dotc.core.Flags.Method)
def isBind: Boolean =
self.is(dotc.core.Flags.Case, butNot = Enum | Module) && !self.isClass
self.is(dotc.core.Flags.Case, butNot = dotc.core.Flags.Enum | dotc.core.Flags.Module) && !self.isClass
def isNoSymbol: Boolean = self == Symbol.noSymbol
def exists: Boolean = self != Symbol.noSymbol

Expand Down Expand Up @@ -2817,6 +2828,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object Flags extends FlagsModule:
def Abstract: Flags = dotc.core.Flags.Abstract
def AbsOverride: Flags = dotc.core.Flags.AbsOverride
def Artifact: Flags = dotc.core.Flags.Artifact
def Case: Flags = dotc.core.Flags.Case
def CaseAccessor: Flags = dotc.core.Flags.CaseAccessor
Expand Down Expand Up @@ -2862,6 +2874,13 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
def Synthetic: Flags = dotc.core.Flags.Synthetic
def Trait: Flags = dotc.core.Flags.Trait
def Transparent: Flags = dotc.core.Flags.Transparent

// Keep: aligned with Quotes's `newMethod` doc
private[QuotesImpl] def validMethodFlags: Flags = Private | Protected | Override | Deferred | Final | Method | Implicit | Given | Local | AbsOverride | JavaStatic // Flags that could be allowed: Synthetic | ExtensionMethod | Exported | Erased | Infix | Invisible
// Keep: aligned with Quotes's `newVal` doc
private[QuotesImpl] def validValFlags: Flags = Private | Protected | Override | Deferred | Final | Param | Implicit | Lazy | Mutable | Local | ParamAccessor | Module | Package | Case | CaseAccessor | Given | Enum | AbsOverride | JavaStatic // Flags that could be added: Synthetic | Erased | Invisible
// Keep: aligned with Quotes's `newBind` doc
private[QuotesImpl] def validBindFlags: Flags = Case // Flags that could be allowed: Implicit | Given | Erased
end Flags

given FlagsMethods: FlagsMethods with
Expand Down Expand Up @@ -2982,7 +3001,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
/** Checks that all definitions in this tree have the expected owner.
* Nested definitions are ignored and assumed to be correct by construction.
*/
private def xCheckMacroedOwners(tree: Option[Tree], owner: Symbol): tree.type =
private def xCheckedMacroOwners(tree: Option[Tree], owner: Symbol): tree.type =
if xCheckMacro then
tree match
case Some(tree) =>
Expand All @@ -2993,7 +3012,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
/** Checks that all definitions in this tree have the expected owner.
* Nested definitions are ignored and assumed to be correct by construction.
*/
private def xCheckMacroedOwners(tree: Tree, owner: Symbol): tree.type =
private def xCheckedMacroOwners(tree: Tree, owner: Symbol): tree.type =
if xCheckMacro then
xCheckMacroOwners(tree, owner)
tree
Expand Down Expand Up @@ -3064,6 +3083,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
"Reference to a method must be eta-expanded before it is used as an expression: " + term.show)
term

private inline def xCheckMacroAssert(inline cond: Boolean, inline msg: String): Unit =
assert(!xCheckMacro || cond, msg)

object Printer extends PrinterModule:

lazy val TreeCode: Printer[Tree] = new Printer[Tree]:
Expand Down
16 changes: 13 additions & 3 deletions library/src/scala/quoted/Quotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3785,9 +3785,10 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
* @param parent The owner of the method
* @param name The name of the method
* @param tpe The type of the method (MethodType, PolyType, ByNameType)
* @param flags extra flags to with which the symbol should be constructed
* @param flags extra flags to with which the symbol should be constructed. `Method` flag will be added. Can be `Private | Protected | Override | Deferred | Final | Method | Implicit | Given | Local | JavaStatic`
* @param privateWithin the symbol within which this new method symbol should be private. May be noSymbol.
*/
// Keep: `flags` doc aligned with QuotesImpl's `validMethodFlags`
def newMethod(parent: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol

/** Generates a new val/var/lazy val symbol with the given parent, name and type.
Expand All @@ -3801,11 +3802,12 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
* @param parent The owner of the val/var/lazy val
* @param name The name of the val/var/lazy val
* @param tpe The type of the val/var/lazy val
* @param flags extra flags to with which the symbol should be constructed
* @param flags extra flags to with which the symbol should be constructed. Can be `Private | Protected | Override | Deferred | Final | Param | Implicit | Lazy | Mutable | Local | ParamAccessor | Module | Package | Case | CaseAccessor | Given | Enum | JavaStatic`
* @param privateWithin the symbol within which this new method symbol should be private. May be noSymbol.
* @note As a macro can only splice code into the point at which it is expanded, all generated symbols must be
* direct or indirect children of the reflection context's owner.
*/
// Keep: `flags` doc aligned with QuotesImpl's `validValFlags`
def newVal(parent: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol

/** Generates a pattern bind symbol with the given parent, name and type.
Expand All @@ -3816,11 +3818,12 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
*
* @param parent The owner of the binding
* @param name The name of the binding
* @param flags extra flags to with which the symbol should be constructed
* @param flags extra flags to with which the symbol should be constructed. `Case` flag will be added. Can be `Case`
* @param tpe The type of the binding
* @note As a macro can only splice code into the point at which it is expanded, all generated symbols must be
* direct or indirect children of the reflection context's owner.
*/
// Keep: `flags` doc aligned with QuotesImpl's `validBindFlags`
def newBind(parent: Symbol, name: String, flags: Flags, tpe: TypeRepr): Symbol

/** Definition not available */
Expand Down Expand Up @@ -4373,6 +4376,13 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
/** Is this symbol `abstract` */
def Abstract: Flags

/** Is this an abstract override method?
*
* This corresponds to a definition declared as "abstract override def" in the source.
* See https://stackoverflow.com/questions/23645172/why-is-abstract-override-required-not-override-alone-in-subtrait for examples.
*/
@experimental def AbsOverride: Flags

/** Is this generated by Scala compiler.
* Corresponds to ACC_SYNTHETIC in the JVM.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ val experimentalDefinitionInLibrary = Set(
//// New APIs: Quotes
// Should be stabilized in 3.4.0
"scala.quoted.Quotes.reflectModule.defnModule.FunctionClass",
"scala.quoted.Quotes.reflectModule.FlagsModule.AbsOverride",
// Can be stabilized in 3.4.0 (unsure) or later
"scala.quoted.Quotes.reflectModule.CompilationInfoModule.XmacroSettings",
"scala.quoted.Quotes.reflectModule.FlagsModule.JavaAnnotation",
Expand Down

0 comments on commit b8d2966

Please sign in to comment.