Skip to content

Commit

Permalink
ast: allow replacing multiple fields with a single
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Dec 1, 2022
1 parent e66dc2f commit a49e17a
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
// step 1a: identify modified fields of the class
val versionedParams = if (isQuasi) Nil else getVersionedParams(params, stats)
val paramsVersions = versionedParams.flatMap(_.getVersions).distinct.sorted(versionOrdering)
val replacedFields = versionedParams.flatMap(
_.replaced.map { case (version, field) => version -> field.oldDef }
)
val replacedFields = versionedParams.flatMap(_.replaced.flatMap { case (version, field) =>
field.oldDefs.map { case (oldDef, _) => version -> oldDef }
})
def paramsForVersion(v: Version): List[ValDef] =
positionVersionedParams(versionedParams.flatMap(_.getApplyDeclDefnBefore(v)._1))

// step 2: validate the body of the class

Expand Down Expand Up @@ -217,9 +219,9 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
addCopy(fullCopyParams)
} else {
// add primary copy with default values
val defaultCopyParams = versionedParams.map { vp =>
getCopyParamWithDefault(vp.getDefaultCopyDef())
}
val defaultCopyParams =
positionVersionedParams(versionedParams.flatMap(_.getDefaultCopyDef()))
.map(getCopyParamWithDefault)
addCopy(defaultCopyParams)

val defaultCopyParamNames = defaultCopyParams.map(_.name.toString).toSet
Expand All @@ -232,7 +234,7 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
addCopy(fullCopyParamsNoDefaults)
// add secondary copy
paramsVersions.foreach { version =>
val copyParams = versionedParams.flatMap(_.getApplyDeclDefnBefore(version)._1)
val copyParams = paramsForVersion(version)
if (copyParams.length != defaultCopyParams.length || !allInDefaults(copyParams))
addCopy(copyParams, getDeprecatedAnno(version))
}
Expand Down Expand Up @@ -350,14 +352,14 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
// with field A, B and additional binary compat ones C, D and E, we generate:
// apply(A, B, C), apply(A, B, C, D), apply(A, B, C, D, E)
paramsVersions.foreach { v =>
val applyParamsBuilder = List.newBuilder[ValDef]
val applyParamsBuilder = List.newBuilder[(ValDef, Int)]
val applyCastBuilder = List.newBuilder[ValDef]
versionedParams.foreach { vp =>
val (decl, defn) = vp.getApplyDeclDefnBefore(v)
decl.foreach(applyParamsBuilder += _)
defn.foreach(applyCastBuilder += _)
}
val params = applyParamsBuilder.result()
val params = positionVersionedParams(applyParamsBuilder.result())
val castFields = applyCastBuilder.result()
val anno = getDeprecatedAnno(v)
mstats1 += q"""
Expand Down Expand Up @@ -386,8 +388,7 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
mstats1 += paramsVersions.headOption.fold {
getUnapply(params)
} { ver =>
val unapplyParams = versionedParams.flatMap(_.getApplyDeclDefnBefore(ver)._1)
getUnapply(unapplyParams, getDeprecatedAnno(ver))
getUnapply(paramsForVersion(ver), getDeprecatedAnno(ver))
}
mstats2 += getUnapply(params)
} else {
Expand Down Expand Up @@ -502,44 +503,71 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
private class VersionedParam(
val param: ValDef,
val appended: Option[Version],
val replaced: Option[(Version, ReplacedField)]
val replaced: Map[Version, ReplacedField]
) {
for {
aver <- appended; (rver, rfield) <- replaced; if versionOrdering.lteq(rver, aver)
} yield c.abort(
param.pos,
s"${versionToString(aver)} [@newField for ${param.name}] must must precede " +
s"${versionToString(rver)} [@replacedField for ${rfield.oldDef.name}]"
)

def getVersions: Iterable[Version] = appended ++ replaced.map(_._1)
def getApplyDeclDefnBefore(version: Version): (Option[ValDef], Option[ValDef]) = {
(appended, replaced) match {
case (Some(aver), _) if versionOrdering.lteq(version, aver) =>
(None, Some(asValDefn(param)))
case (_, Some((rver, rfield))) if versionOrdering.lteq(version, rver) =>
(Some(asValDecl(rfield.oldDef)), Some(rfield.newValDefn))
appended.foreach { aver =>
replaced.foreach { case (rver, rfield) =>
if (versionOrdering.lteq(rver, aver)) {
val oldDef = rfield.oldDefs.head._1
c.abort(
param.pos,
s"${versionToString(aver)} [@newField for ${param.name}] must must precede " +
s"${versionToString(rver)} [@replacedField for ${oldDef.name}]"
)
}
}
}

def getVersions: Iterable[Version] = appended ++ replaced.keys
def getApplyDeclDefnBefore(version: Version): (List[(ValDef, Int)], Option[ValDef]) = {
appended match {
case Some(aver) if versionOrdering.lteq(version, aver) =>
(Nil, Some(asValDefn(param)))
case _ =>
(Some(asValDecl(param)), None)
val records = replaced.iterator.filter(x => versionOrdering.gteq(version, x._1))
(if (records.isEmpty) None else Some(records.maxBy(_._1)(versionOrdering))) match {
case Some((_, rfield)) =>
val decls = rfield.oldDefs.map { case (oldDef, pos) => asValDecl(oldDef) -> pos }
(decls, Some(rfield.newValDefn))
case _ =>
(asValDecl(param) -> -1 :: Nil, None)
}
}
}
def getDefaultCopyDef(): ValOrDefDef = {
replaced match {
case Some((_, rfield)) => rfield.oldDef
case _ => param
def getDefaultCopyDef(): List[(ValOrDefDef, Int)] = {
if (replaced.isEmpty) (param, -1) :: Nil
else replaced.minBy(_._1)(versionOrdering)._2.oldDefs
}
}

private def positionVersionedParams[A](params: List[(A, Int)]): List[A] = {
val res = new ListBuffer[A]
val paramIter = params.iterator.filter(_._2 < 0)
@tailrec
def iter(withPositions: List[(A, Int)]): Unit = {
withPositions match {
case (v, pos) :: rest =>
paramIter.take(pos - res.length).foreach { case (x, _) => res += x }
res += v
iter(rest)
case _ =>
paramIter.foreach { case (x, _) => res += x }
}
}
iter(params.filter(_._2 >= 0).sortBy(_._2))
res.toList
}

private def getVersionedParams(
params: List[ValDef],
stats: List[Tree]
): List[VersionedParam] = {
val appendedFields: Map[String, Version] = getNewFieldVersions(params)
val replacedFields: Map[String, (Version, ReplacedField)] = ReplacedField.getMap(params, stats)
val replacedFields: Map[String, Map[Version, ReplacedField]] =
ReplacedField.getMap(params, stats)
params.map { p =>
val pname = p.name.toString
new VersionedParam(p, appendedFields.get(pname), replacedFields.get(pname))
new VersionedParam(p, appendedFields.get(pname), replacedFields.getOrElse(pname, Map.empty))
}
}

Expand Down Expand Up @@ -586,9 +614,13 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
builder.result()
}

private class ReplacedField(val oldDef: ValOrDefDef, val newVal: ValDef, ctor: Tree) {
private class ReplacedField(
val newVal: ValDef,
ctor: Tree,
val oldDefs: List[(ValOrDefDef, Int)]
) {
def newValDefn: ValDef = {
val body =
def bodyForSingleOldDef(oldDef: ValOrDefDef) =
if (ctor eq null)
q"""
import scala.meta.trees._
Expand All @@ -598,6 +630,19 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
q"""
$ctor(${oldDef.name})
"""
def bodyForMultipleOldDefs = {
if (ctor eq null) c.abort(newVal.pos, s"${newVal.name} must define a ctor")
val names = oldDefs.map { case (oldDef, _) =>
val name = q"${oldDef.name}"
val arg = AssignOrNamedArg(name, name)
q"$arg"
}
q"$ctor(..$names)"
}
val body = oldDefs match {
case (oldDef, _) :: Nil => bodyForSingleOldDef(oldDef)
case _ => bodyForMultipleOldDefs
}
q"""
val ${newVal.name}: ${deannotateType(newVal)} = {
..$body
Expand All @@ -607,27 +652,45 @@ class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
}

private object ReplacedField {
def getMap(params: List[ValDef], stats: List[Tree]): Map[String, (Version, ReplacedField)] = {
val fields: Map[String, ValDef] = params.map(p => p.name.toString -> p).toMap
stats.flatMap {
def getMap(
params: List[ValDef],
stats: List[Tree]
): Map[String, Map[Version, ReplacedField]] = {
val fields: Map[String, (ValDef, Map[Version, Tree])] = params.map { p =>
val ctorsByVersion = p.mods.annotations.collect {
case q"new replacesFields($since, $ctor)" =>
val version = parseVersionAnnot(since, "replacesFields", "since")
version -> ctor
}.toMap
p.name.toString -> (p, ctorsByVersion)
}.toMap
val replacedFields = stats.flatMap {
case p: ValOrDefDef =>
val anno = p.mods.annotations.collectFirst {
case q"new replacedField($until)" => (until, null)
case q"new replacedField($until, $ctor)" => (until, ctor)
case q"new replacedField($until)" => (until, -1)
case q"new replacedField($until, $pos)" =>
(until, getAnnotAttribute(pos).toInt)
}
anno.map { case (until, ctor) =>
anno.map { case (until, pos) =>
if (!p.mods.hasFlag(Flag.FINAL))
c.abort(p.pos, "replacedField-annotated fields must be final")
val version = parseVersionAnnot(until, "replacedField", "until")
val newField = getNewField(p)
val newVal = fields.getOrElse(
newField,
c.abort(p.pos, s"@replacedField: field `$newField` is undefined)")
)
newVal.name.toString -> (version, new ReplacedField(p, newVal, ctor))
newField -> (p, version, pos)
}
case _ => None
}.toMap
}
replacedFields.groupBy(_._1).map { case (k, v) =>
val (newVal, ctorsByVersion) = fields.getOrElse(
k,
c.abort(v.head._2._1.pos, s"@replacedField: field `$k` is undefined)")
)
k -> v.map(_._2).groupBy(_._2).map { case (ver, oldFields) =>
val ctor = ctorsByVersion.get(ver).orNull
val oldDefs = oldFields.map { case (oldField, _, pos) => (oldField, pos) }
ver -> new ReplacedField(newVal, ctor, oldDefs)
}
}
}

private def getNewField(oldDef: ValOrDefDef): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ object Metadata {
class branch extends StaticAnnotation
class astClass extends StaticAnnotation
class newField(since: String) extends StaticAnnotation
class replacedField(until: String, ctor: Any = null) extends StaticAnnotation
class replacedField(until: String, pos: Int = -1) extends StaticAnnotation
class replacesFields(since: String, ctor: Any) extends StaticAnnotation
class astCompanion extends StaticAnnotation
@getter class astField extends StaticAnnotation
@getter class auxiliary extends StaticAnnotation
Expand Down
24 changes: 13 additions & 11 deletions scalameta/trees/shared/src/main/scala/scala/meta/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import scala.meta.inputs._
import scala.meta.tokens._
import scala.meta.prettyprinters._
import scala.meta.internal.trees._
import scala.meta.internal.trees.Metadata.{newField, replacedField}
import scala.meta.internal.trees.Metadata.{newField, replacedField, replacesFields}
import scala.{meta => sm}

@root trait Tree extends InternalTree {
Expand Down Expand Up @@ -253,29 +253,31 @@ object Type {
}

@ast class FuncParamClause(values: List[Type]) extends Tree with Member.SyntaxValuesClause
object FunctionType {
private[meta] final val paramsToClause: List[Type] => FuncParamClause =
FuncParamClause.apply
}
@branch trait FunctionType extends Type with Tree.WithBody {
def paramClause: FuncParamClause
@deprecated("Please use paramClause instead", "4.6.0")
def params: List[Type]
def res: Type
override def body: Tree = res
}
@ast class Function(paramClause: FuncParamClause, res: Type) extends FunctionType {
@replacedField("4.6.0", FunctionType.paramsToClause) final override def params: List[Type] =
paramClause.values
@ast class Function(
@replacesFields("4.6.0", FuncParamClause)
paramClause: FuncParamClause,
res: Type
) extends FunctionType {
@replacedField("4.6.0") final override def params: List[Type] = paramClause.values
}
@ast class PolyFunction(tparamClause: ParamClause, tpe: Type)
extends Type with Tree.WithTParamClause with Tree.WithBody {
@replacedField("4.6.0") final def tparams: List[Param] = tparamClause.values
override final def body: Tree = tpe
}
@ast class ContextFunction(paramClause: FuncParamClause, res: Type) extends FunctionType {
@replacedField("4.6.0", FunctionType.paramsToClause) final override def params: List[Type] =
paramClause.values
@ast class ContextFunction(
@replacesFields("4.6.0", FuncParamClause)
paramClause: FuncParamClause,
res: Type
) extends FunctionType {
@replacedField("4.6.0") final override def params: List[Type] = paramClause.values
}
@ast @deprecated("Implicit functions are not supported in any dialect")
class ImplicitFunction(
Expand Down

0 comments on commit a49e17a

Please sign in to comment.