Skip to content

Commit

Permalink
Allow macro annotation to transform companion (#19677)
Browse files Browse the repository at this point in the history
### Allow MacroAnnotations to update the companion of a definition

We extend the MacroAnnotation api to allow to modify the companion of a
class or an object.

### Specification

1. Order of expansion

- We expand the definitions in program order. 
- We expand the annotations of the outer scope first, then we expand the
inner definitions.
- Annotations are expanded from the outer annotation to the inner
annotation.

In the following example, we expand the annotations in this order: `a1`,
`a2`, `a3`.

```scala
@A1 @a2
class Foo:
  @A3 def foo = ???
```
2. Expansion of the companion

We always expand the latest available tree. If an annotation defined on
`class Foo` changes its companion (`object Foo`) and the `class` is
defined before `object`, the expansion of the annotations on the
`object` will be performed on the result of the expansion of `class`.

3. The program order is maintained

We maintain the program order in the definitions that were expanded.

4. Backtrack and reprocess

Example:

```scala
@A1 class Foo
@a2 object Foo
```
If the `@a2` annotation changes the definitions in `class Foo`, we will
rerun the algorithm on the result of this new expansion. Please note
that we don't allow to generate code with MacroAnnotations, the reason
for rerunning the algorithm is to expand and inline possible macros that
we generated.

---
Closes #19676
  • Loading branch information
nicolasstucki committed Apr 30, 2024
2 parents e2c456f + 4694b3b commit 4f39236
Show file tree
Hide file tree
Showing 75 changed files with 631 additions and 350 deletions.
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/CompilationUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import config.{SourceVersion, Feature}
import StdNames.nme
import scala.annotation.internal.sharable
import scala.util.control.NoStackTrace
import transform.MacroAnnotations
import transform.MacroAnnotations.isMacroAnnotation

class CompilationUnit protected (val source: SourceFile, val info: CompilationUnitInfo | Null) {

Expand Down Expand Up @@ -197,7 +197,7 @@ object CompilationUnit {
case _ =>
case _ =>
for annot <- tree.symbol.annotations do
if MacroAnnotations.isMacroAnnotation(annot) then
if annot.isMacroAnnotation then
ctx.compilationUnit.hasMacroAnnotations = true
traverseChildren(tree)
}
Expand Down
77 changes: 77 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeMapWithTrackedStats.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package dotty.tools.dotc
package ast

import tpd.*
import core.Contexts.*
import core.Symbols.*
import util.Property

import scala.collection.mutable

/**
* It is safe to assume that the companion of a tree is in the same scope.
* Therefore, when expanding MacroAnnotations, we will only keep track of
* the trees in the same scope as the current transformed tree
*/
abstract class TreeMapWithTrackedStats extends TreeMapWithImplicits:

import TreeMapWithTrackedStats.*

/** Fetch the corresponding tracked tree for a given symbol */
protected final def getTracked(sym: Symbol)(using Context): Option[MemberDef] =
for trees <- ctx.property(TrackedTrees)
tree <- trees.get(sym)
yield tree

/** Update the tracked trees */
protected final def updateTracked(tree: Tree)(using Context): Tree =
tree match
case tree: MemberDef =>
trackedTrees.update(tree.symbol, tree)
tree
case _ => tree
end updateTracked

/** Process a list of trees and give the priority to trakced trees */
private final def withUpdatedTrackedTrees(stats: List[Tree])(using Context) =
val trackedTrees = TreeMapWithTrackedStats.trackedTrees
stats.mapConserve:
case tree: MemberDef if trackedTrees.contains(tree.symbol) =>
trackedTrees(tree.symbol)
case stat => stat

override def transform(tree: Tree)(using Context): Tree =
tree match
case PackageDef(_, stats) =>
inContext(trackedDefinitionsCtx(stats)): // Step I: Collect and memoize all the definition trees
// Step II: Transform the tree
val pkg@PackageDef(pid, stats) = super.transform(tree): @unchecked
// Step III: Reconcile between the symbols in syms and the tree
cpy.PackageDef(pkg)(pid = pid, stats = withUpdatedTrackedTrees(stats))
case block: Block =>
inContext(trackedDefinitionsCtx(block.stats)): // Step I: Collect all the member definitions in the block
// Step II: Transform the tree
val b@Block(stats, expr) = super.transform(tree): @unchecked
// Step III: Reconcile between the symbols in syms and the tree
cpy.Block(b)(expr = expr, stats = withUpdatedTrackedTrees(stats))
case TypeDef(_, impl: Template) =>
inContext(trackedDefinitionsCtx(impl.body)): // Step I: Collect and memoize all the stats
// Step II: Transform the tree
val newTree@TypeDef(name, impl: Template) = super.transform(tree): @unchecked
// Step III: Reconcile between the symbols in syms and the tree
cpy.TypeDef(newTree)(rhs = cpy.Template(impl)(body = withUpdatedTrackedTrees(impl.body)))
case _ => super.transform(tree)

end TreeMapWithTrackedStats

object TreeMapWithTrackedStats:
private val TrackedTrees = new Property.Key[mutable.Map[Symbol, tpd.MemberDef]]

/** Fetch the tracked trees in the cuurent context */
private def trackedTrees(using Context): mutable.Map[Symbol, MemberDef] =
ctx.property(TrackedTrees).get

/** Build a context and track the provided MemberDef trees */
private def trackedDefinitionsCtx(stats: List[Tree])(using Context): Context =
val treesToTrack = stats.collect { case m: MemberDef => (m.symbol, m) }
ctx.fresh.setProperty(TrackedTrees, mutable.Map(treesToTrack*))
91 changes: 60 additions & 31 deletions compiler/src/dotty/tools/dotc/transform/Inlining.scala
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
package dotty.tools.dotc
package transform

import ast.tpd
import ast.Trees.*
import ast.TreeMapWithTrackedStats
import core.*
import Flags.*
import Decorators.*
import Contexts.*
import Symbols.*
import Decorators.*
import config.Printers.inlining
import DenotTransformers.IdentityDenotTransformer
import MacroAnnotations.hasMacroAnnotation
import inlines.Inlines
import quoted.*
import staging.StagingLevel
import util.Property

import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.ast.Trees.*
import dotty.tools.dotc.quoted.*
import dotty.tools.dotc.inlines.Inlines
import dotty.tools.dotc.ast.TreeMapWithImplicits
import dotty.tools.dotc.core.DenotTransformers.IdentityDenotTransformer
import dotty.tools.dotc.staging.StagingLevel

import scala.collection.mutable.ListBuffer
import scala.collection.mutable

/** Inlines all calls to inline methods that are not in an inline method or a quote */
class Inlining extends MacroTransform, IdentityDenotTransformer {
Expand Down Expand Up @@ -56,38 +60,21 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {

def newTransformer(using Context): Transformer = new Transformer {
override def transform(tree: tpd.Tree)(using Context): tpd.Tree =
new InliningTreeMap().transform(tree)
InliningTreeMap().transform(tree)
}

private class InliningTreeMap extends TreeMapWithImplicits {
private class InliningTreeMap extends TreeMapWithTrackedStats {

/** List of top level classes added by macro annotation in a package object.
* These are added to the PackageDef that owns this particular package object.
*/
private val newTopClasses = MutableSymbolMap[ListBuffer[Tree]]()
private val newTopClasses = MutableSymbolMap[mutable.ListBuffer[Tree]]()

override def transform(tree: Tree)(using Context): Tree = {
tree match
case tree: MemberDef =>
if tree.symbol.is(Inline) then tree
else if tree.symbol.is(Param) then super.transform(tree)
else if
!tree.symbol.isPrimaryConstructor
&& StagingLevel.level == 0
&& MacroAnnotations.hasMacroAnnotation(tree.symbol)
then
val trees = (new MacroAnnotations(self)).expandAnnotations(tree)
val trees1 = trees.map(super.transform)

// Find classes added to the top level from a package object
val (topClasses, trees2) =
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
else (Nil, trees1)
if topClasses.nonEmpty then
newTopClasses.getOrElseUpdate(ctx.owner.owner, new ListBuffer) ++= topClasses

flatTree(trees2)
else super.transform(tree)
// Fetch the latest tracked tree (It might have already been transformed by its companion)
transformMemberDef(getTracked(tree.symbol).getOrElse(tree))
case _: Typed | _: Block =>
super.transform(tree)
case _: PackageDef =>
Expand All @@ -113,7 +100,49 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {
else Inlines.inlineCall(tree1)
else super.transform(tree)
}

private def transformMemberDef(tree: MemberDef)(using Context) : Tree =
if tree.symbol.is(Inline) then tree
else if tree.symbol.is(Param) then
super.transform(tree)
else if
!tree.symbol.isPrimaryConstructor
&& StagingLevel.level == 0
&& tree.symbol.hasMacroAnnotation
then
// Fetch the companion's tree
val companionSym =
if tree.symbol.is(ModuleClass) then tree.symbol.companionClass
else if tree.symbol.is(ModuleVal) then NoSymbol
else tree.symbol.companionModule.moduleClass

// Expand and process MacroAnnotations
val companion = getTracked(companionSym)
val (trees, newCompanion) = MacroAnnotations.expandAnnotations(tree, companion)

// Enter the new symbols & Update the tracked trees
(newCompanion.toList ::: trees).foreach: tree =>
MacroAnnotations.enterMissingSymbols(tree, self)

// Perform inlining on the expansion of the annotations
val trees1 = trees.map(super.transform)
trees1.foreach(updateTracked)
if newCompanion ne companion then
newCompanion.map(super.transform).foreach(updateTracked)

// Find classes added to the top level from a package object
val (topClasses, trees2) =
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
else (Nil, trees1)
if topClasses.nonEmpty then
newTopClasses.getOrElseUpdate(ctx.owner.owner, new mutable.ListBuffer) ++= topClasses
flatTree(trees2)
else
updateTracked(super.transform(tree))
end transformMemberDef

}

}

object Inlining:
Expand Down
Loading

0 comments on commit 4f39236

Please sign in to comment.