Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ast: allow replacing multiple fields with a single #2982

Merged
merged 1 commit into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
// step 1a: identify modified fields of the class
val versionedParams = if (isQuasi) Nil else getVersionedParams(params, stats)
val paramsVersions = versionedParams.flatMap(_.getVersions).distinct.sorted(versionOrdering)
val replacedFields = versionedParams.flatMap(
_.replaced.map { case (version, field) => version -> field.oldDef }
)
val replacedFields = versionedParams.flatMap(_.replaced.flatMap { case (version, field) =>
field.oldDefs.map { case (oldDef, _) => version -> oldDef }
})
def paramsForVersion(v: Version): List[ValDef] =
positionVersionedParams(versionedParams.flatMap(_.getApplyDeclDefnBefore(v)._1))

// step 2: validate the body of the class

Expand Down Expand Up @@ -217,9 +219,9 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
addCopy(fullCopyParams)
} else {
// add primary copy with default values
val defaultCopyParams = versionedParams.map { vp =>
getCopyParamWithDefault(vp.getDefaultCopyDef())
}
val defaultCopyParams =
positionVersionedParams(versionedParams.flatMap(_.getDefaultCopyDef()))
.map(getCopyParamWithDefault)
addCopy(defaultCopyParams)

val defaultCopyParamNames = defaultCopyParams.map(_.name.toString).toSet
Expand All @@ -232,7 +234,7 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
addCopy(fullCopyParamsNoDefaults)
// add secondary copy
paramsVersions.foreach { version =>
val copyParams = versionedParams.flatMap(_.getApplyDeclDefnBefore(version)._1)
val copyParams = paramsForVersion(version)
if (copyParams.length != defaultCopyParams.length || !allInDefaults(copyParams))
addCopy(copyParams, getDeprecatedAnno(version))
}
Expand Down Expand Up @@ -350,14 +352,14 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
// with field A, B and additional binary compat ones C, D and E, we generate:
// apply(A, B, C), apply(A, B, C, D), apply(A, B, C, D, E)
paramsVersions.foreach { v =>
val applyParamsBuilder = List.newBuilder[ValDef]
val applyParamsBuilder = List.newBuilder[(ValDef, Int)]
val applyCastBuilder = List.newBuilder[ValDef]
versionedParams.foreach { vp =>
val (decl, defn) = vp.getApplyDeclDefnBefore(v)
decl.foreach(applyParamsBuilder += _)
defn.foreach(applyCastBuilder += _)
}
val params = applyParamsBuilder.result()
val params = positionVersionedParams(applyParamsBuilder.result())
val castFields = applyCastBuilder.result()
val anno = getDeprecatedAnno(v)
mstats1 += q"""
Expand Down Expand Up @@ -386,8 +388,7 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
mstats1 += paramsVersions.headOption.fold {
getUnapply(params)
} { ver =>
val unapplyParams = versionedParams.flatMap(_.getApplyDeclDefnBefore(ver)._1)
getUnapply(unapplyParams, getDeprecatedAnno(ver))
getUnapply(paramsForVersion(ver), getDeprecatedAnno(ver))
}
mstats2 += getUnapply(params)
} else {
Expand Down Expand Up @@ -500,44 +501,69 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
private class VersionedParam(
val param: ValDef,
val appended: Option[Version],
val replaced: Option[(Version, ReplacedField)]
val replaced: Map[Version, ReplacedField]
) {
for {
aver <- appended; (rver, rfield) <- replaced; if versionOrdering.lteq(rver, aver)
} yield c.abort(
param.pos,
s"${versionToString(aver)} [@newField for ${param.name}] must must precede " +
s"${versionToString(rver)} [@replacedField for ${rfield.oldDef.name}]"
)

def getVersions: Iterable[Version] = appended ++ replaced.map(_._1)
def getApplyDeclDefnBefore(version: Version): (Option[ValDef], Option[ValDef]) = {
(appended, replaced) match {
case (Some(aver), _) if versionOrdering.lteq(version, aver) =>
(None, Some(asValDefn(param)))
case (_, Some((rver, rfield))) if versionOrdering.lteq(version, rver) =>
(Some(asValDecl(rfield.oldDef)), Some(rfield.newValDefn))
case _ =>
(Some(asValDecl(param)), None)
appended.foreach { aver =>
replaced.foreach { case (rver, rfield) =>
if (versionOrdering.lteq(rver, aver)) {
val oldDef = rfield.oldDefs.head._1
c.abort(
param.pos,
s"${versionToString(aver)} [@newField for ${param.name}] must must precede " +
s"${versionToString(rver)} [@replacedField for ${oldDef.name}]"
)
}
}
}
def getDefaultCopyDef(): ValOrDefDef = {
replaced match {
case Some((_, rfield)) => rfield.oldDef
case _ => param

def getVersions: Iterable[Version] = appended ++ replaced.keys
def getApplyDeclDefnBefore(version: Version): (List[(ValDef, Int)], Option[ValDef]) = {
def checkVersion(ver: Version): Boolean = versionOrdering.lteq(version, ver)
if (appended.exists(checkVersion)) (Nil, Some(asValDefn(param)))
else {
val records = replaced.iterator.filter(x => checkVersion(x._1))
if (records.isEmpty) (asValDecl(param) -> -1 :: Nil, None)
else {
val rfield = records.minBy(_._1)(versionOrdering)._2
val decls = rfield.oldDefs.map { case (oldDef, pos) => asValDecl(oldDef) -> pos }
(decls, Some(rfield.newValDefn))
}
}
}
def getDefaultCopyDef(): List[(ValOrDefDef, Int)] = {
if (replaced.isEmpty) (param, -1) :: Nil
else replaced.minBy(_._1)(versionOrdering)._2.oldDefs
}
}

private def positionVersionedParams[A](params: List[(A, Int)]): List[A] = {
val res = new ListBuffer[A]
val paramIter = params.iterator.filter(_._2 < 0)
@tailrec
def iter(withPositions: List[(A, Int)]): Unit = {
withPositions match {
case (v, pos) :: rest =>
paramIter.take(pos - res.length).foreach { case (x, _) => res += x }
res += v
iter(rest)
case _ =>
paramIter.foreach { case (x, _) => res += x }
}
}
iter(params.filter(_._2 >= 0).sortBy(_._2))
res.toList
}

private def getVersionedParams(
params: List[ValDef],
stats: List[Tree]
): List[VersionedParam] = {
val appendedFields: Map[String, Version] = getNewFieldVersions(params)
val replacedFields: Map[String, (Version, ReplacedField)] = ReplacedField.getMap(params, stats)
val replacedFields: Map[String, Map[Version, ReplacedField]] =
ReplacedField.getMap(params, stats)
params.map { p =>
val pname = p.name.toString
new VersionedParam(p, appendedFields.get(pname), replacedFields.get(pname))
new VersionedParam(p, appendedFields.get(pname), replacedFields.getOrElse(pname, Map.empty))
}
}

Expand Down Expand Up @@ -584,9 +610,13 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
builder.result()
}

private class ReplacedField(val oldDef: ValOrDefDef, val newVal: ValDef, ctor: Tree) {
private class ReplacedField(
val newVal: ValDef,
ctor: Tree,
val oldDefs: List[(ValOrDefDef, Int)]
) {
def newValDefn: ValDef = {
val body =
def bodyForSingleOldDef(oldDef: ValOrDefDef) =
if (ctor eq null)
q"""
import scala.meta.trees._
Expand All @@ -596,6 +626,19 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
q"""
$ctor(${oldDef.name})
"""
def bodyForMultipleOldDefs = {
if (ctor eq null) c.abort(newVal.pos, s"${newVal.name} must define a ctor")
val names = oldDefs.map { case (oldDef, _) =>
val name = q"${oldDef.name}"
val arg = AssignOrNamedArg(name, name)
q"$arg"
}
q"$ctor(..$names)"
}
val body = oldDefs match {
case (oldDef, _) :: Nil => bodyForSingleOldDef(oldDef)
case _ => bodyForMultipleOldDefs
}
q"""
val ${newVal.name}: ${deannotateType(newVal)} = {
..$body
Expand All @@ -605,27 +648,45 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
}

private object ReplacedField {
def getMap(params: List[ValDef], stats: List[Tree]): Map[String, (Version, ReplacedField)] = {
val fields: Map[String, ValDef] = params.map(p => p.name.toString -> p).toMap
stats.flatMap {
def getMap(
params: List[ValDef],
stats: List[Tree]
): Map[String, Map[Version, ReplacedField]] = {
val fields: Map[String, (ValDef, Map[Version, Tree])] = params.map { p =>
val ctorsByVersion = p.mods.annotations.collect {
case q"new replacesFields($since, $ctor)" =>
val version = parseVersionAnnot(since, "replacesFields", "since")
version -> ctor
}.toMap
p.name.toString -> (p, ctorsByVersion)
}.toMap
val replacedFields = stats.flatMap {
case p: ValOrDefDef =>
val anno = p.mods.annotations.collectFirst {
case q"new replacedField($until)" => (until, null)
case q"new replacedField($until, $ctor)" => (until, ctor)
case q"new replacedField($until)" => (until, -1)
case q"new replacedField($until, $pos)" =>
(until, getAnnotAttribute(pos).toInt)
}
anno.map { case (until, ctor) =>
anno.map { case (until, pos) =>
if (!p.mods.hasFlag(Flag.FINAL))
c.abort(p.pos, "replacedField-annotated fields must be final")
val version = parseVersionAnnot(until, "replacedField", "until")
val newField = getNewField(p)
val newVal = fields.getOrElse(
newField,
c.abort(p.pos, s"@replacedField: field `$newField` is undefined)")
)
newVal.name.toString -> (version, new ReplacedField(p, newVal, ctor))
newField -> (p, version, pos)
}
case _ => None
}.toMap
}
replacedFields.groupBy(_._1).map { case (k, v) =>
val (newVal, ctorsByVersion) = fields.getOrElse(
k,
c.abort(v.head._2._1.pos, s"@replacedField: field `$k` is undefined)")
)
k -> v.map(_._2).groupBy(_._2).map { case (ver, oldFields) =>
val ctor = ctorsByVersion.get(ver).orNull
val oldDefs = oldFields.map { case (oldField, _, pos) => (oldField, pos) }
ver -> new ReplacedField(newVal, ctor, oldDefs)
}
}
}

private def getNewField(oldDef: ValOrDefDef): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ object Metadata {
class branch extends StaticAnnotation
class astClass extends StaticAnnotation
class newField(since: String) extends StaticAnnotation
class replacedField(until: String, ctor: Any = null) extends StaticAnnotation
class replacedField(until: String, pos: Int = -1) extends StaticAnnotation
class replacesFields(since: String, ctor: Any) extends StaticAnnotation
class astCompanion extends StaticAnnotation
@getter class astField extends StaticAnnotation
@getter class auxiliary extends StaticAnnotation
Expand Down
24 changes: 13 additions & 11 deletions scalameta/trees/shared/src/main/scala/scala/meta/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import scala.meta.inputs._
import scala.meta.tokens._
import scala.meta.prettyprinters._
import scala.meta.internal.trees._
import scala.meta.internal.trees.Metadata.{newField, replacedField}
import scala.meta.internal.trees.Metadata.{newField, replacedField, replacesFields}
import scala.{meta => sm}

@root trait Tree extends InternalTree {
Expand Down Expand Up @@ -253,29 +253,31 @@ object Type {
}

@ast class FuncParamClause(values: List[Type]) extends Tree with Member.SyntaxValuesClause
object FunctionType {
private[meta] final val paramsToClause: List[Type] => FuncParamClause =
FuncParamClause.apply
}
@branch trait FunctionType extends Type with Tree.WithBody {
def paramClause: FuncParamClause
@deprecated("Please use paramClause instead", "4.6.0")
def params: List[Type]
def res: Type
override def body: Tree = res
}
@ast class Function(paramClause: FuncParamClause, res: Type) extends FunctionType {
@replacedField("4.6.0", FunctionType.paramsToClause) final override def params: List[Type] =
paramClause.values
@ast class Function(
@replacesFields("4.6.0", FuncParamClause)
paramClause: FuncParamClause,
res: Type
) extends FunctionType {
@replacedField("4.6.0") final override def params: List[Type] = paramClause.values
}
@ast class PolyFunction(tparamClause: ParamClause, tpe: Type)
extends Type with Tree.WithTParamClause with Tree.WithBody {
@replacedField("4.6.0") final def tparams: List[Param] = tparamClause.values
override final def body: Tree = tpe
}
@ast class ContextFunction(paramClause: FuncParamClause, res: Type) extends FunctionType {
@replacedField("4.6.0", FunctionType.paramsToClause) final override def params: List[Type] =
paramClause.values
@ast class ContextFunction(
@replacesFields("4.6.0", FuncParamClause)
paramClause: FuncParamClause,
res: Type
) extends FunctionType {
@replacedField("4.6.0") final override def params: List[Type] = paramClause.values
}
@ast @deprecated("Implicit functions are not supported in any dialect")
class ImplicitFunction(
Expand Down