Skip to content

Commit

Permalink
Add support for companion in MacroAnnotations
Browse files Browse the repository at this point in the history
  • Loading branch information
hamzaremmal committed Apr 15, 2024
1 parent 54d67e0 commit ef74f4d
Show file tree
Hide file tree
Showing 68 changed files with 561 additions and 302 deletions.
164 changes: 132 additions & 32 deletions compiler/src/dotty/tools/dotc/transform/Inlining.scala
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
package dotty.tools.dotc
package transform

import ast.tpd
import ast.Trees.*
import ast.TreeMapWithImplicits
import core.*
import Flags.*
import Decorators.*
import Contexts.*
import Symbols.*
import Decorators.*
import config.Printers.inlining
import DenotTransformers.IdentityDenotTransformer
import inlines.Inlines
import quoted.*
import staging.StagingLevel

import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.ast.Trees.*
import dotty.tools.dotc.quoted.*
import dotty.tools.dotc.inlines.Inlines
import dotty.tools.dotc.ast.TreeMapWithImplicits
import dotty.tools.dotc.core.DenotTransformers.IdentityDenotTransformer
import dotty.tools.dotc.staging.StagingLevel

import scala.collection.mutable.ListBuffer
import scala.collection.mutable

/** Inlines all calls to inline methods that are not in an inline method or a quote */
class Inlining extends MacroTransform, IdentityDenotTransformer {
Expand Down Expand Up @@ -56,38 +58,82 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {

def newTransformer(using Context): Transformer = new Transformer {
override def transform(tree: tpd.Tree)(using Context): tpd.Tree =
new InliningTreeMap().transform(tree)
InliningTreeMap().transform(tree)
}

private class MemoizeStatsTreeMap extends TreeMapWithImplicits {

// It is safe to assume that the companion of a tree is in the same scope
// Therefore, when expanding MacroAnnotations, we will only keep track of
// the trees in the same scope as the current transformed tree

override def transform(tree: Tree)(using Context): Tree =
tree match
case PackageDef(_, stats) =>
// Phase I: Collect and memoize all the stats
val treesToTrack = stats.collect { case m: MemberDef => (m.symbol, m) }
val withTrackedTreeCtx = MacroAnnotations.trackedTreesCtx(mutable.Map(treesToTrack*))
// Phase II & III: Transform the tree with this definitions and reconcile them with the tracked trees
super.transform(tree)(using withTrackedTreeCtx) match
case pkg@PackageDef(pid, stats) =>
val trackedTree = MacroAnnotations.trackedTrees(using withTrackedTreeCtx)
val updatedStats = stats.mapConserve:
case tree: MemberDef if trackedTree.contains(tree.symbol) =>
trackedTree(tree.symbol)
case stat => stat
cpy.PackageDef(pkg)(pid = pid, stats = updatedStats)
case tree => tree
case block: Block =>
// Phase I: Fetch all the member definitions in the block
val trackedTrees = block.stats.collect { case m: MemberDef => (m.symbol, m) }
val withTrackedTreeCtx = MacroAnnotations.trackedTreesCtx(mutable.Map(trackedTrees*))

// Phase II / III: Transform the tree and Reconcile between the symbols in syms and the tree
// TODO: Should we define a substitution method where we change the trees
// and not the symbols (see Tree::subst)
// result.subst(MacroAnnotations.trackedTrees(using withTrackedTreeCtx))
super.transform(tree)(using withTrackedTreeCtx) match
case b@Block(stats, expr) =>
val trackedTree = MacroAnnotations.trackedTrees(using withTrackedTreeCtx)
cpy.Block(b)(
expr = expr,
stats = stats.mapConserve:
case ddef: MemberDef if trackedTree.contains(ddef.symbol) =>
trackedTree(ddef.symbol)
case stat => stat
)
case tree => tree
case TypeDef(_, impl: Template) =>
// Phase I: Collect and memoize all the stats
val treesToTrack = impl.body.collect { case m: MemberDef => (m.symbol, m) }
val withTrackedTreeCtx = MacroAnnotations.trackedTreesCtx(mutable.Map(treesToTrack*))
// Phase II / III: Transform the tree and Reconcile between the symbols in syms and the tree
super.transform(tree)(using withTrackedTreeCtx) match
case tree@TypeDef(name, impl: Template) =>
val trackedTree = MacroAnnotations.trackedTrees(using withTrackedTreeCtx)
cpy.TypeDef(tree)(
name = name,
rhs = cpy.Template(impl)(
body = impl.body.mapConserve:
case ddef: MemberDef if trackedTree.contains(ddef.symbol) =>
trackedTree(ddef.symbol)
case stat => stat
)
)
case tree => tree
case _ => super.transform(tree)
}

private class InliningTreeMap extends TreeMapWithImplicits {
private class InliningTreeMap extends MemoizeStatsTreeMap {

/** List of top level classes added by macro annotation in a package object.
* These are added to the PackageDef that owns this particular package object.
*/
private val newTopClasses = MutableSymbolMap[ListBuffer[Tree]]()
private val newTopClasses = MutableSymbolMap[mutable.ListBuffer[Tree]]()

override def transform(tree: Tree)(using Context): Tree = {
tree match
case tree: MemberDef =>
if tree.symbol.is(Inline) then tree
else if tree.symbol.is(Param) then super.transform(tree)
else if
!tree.symbol.isPrimaryConstructor
&& StagingLevel.level == 0
&& MacroAnnotations.hasMacroAnnotation(tree.symbol)
then
val trees = (new MacroAnnotations(self)).expandAnnotations(tree)
val trees1 = trees.map(super.transform)

// Find classes added to the top level from a package object
val (topClasses, trees2) =
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
else (Nil, trees1)
if topClasses.nonEmpty then
newTopClasses.getOrElseUpdate(ctx.owner.owner, new ListBuffer) ++= topClasses

flatTree(trees2)
else super.transform(tree)
case tree: MemberDef => transformMemberDef(tree)
case _: Typed | _: Block =>
super.transform(tree)
case _: PackageDef =>
Expand All @@ -113,7 +159,61 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {
else Inlines.inlineCall(tree1)
else super.transform(tree)
}

private def transformMemberDef(tree: MemberDef)(using Context) : Tree =
if tree.symbol.is(Inline) then tree
else if tree.symbol.is(Param) then
super.transform(tree)
else if
!tree.symbol.isPrimaryConstructor
&& StagingLevel.level == 0
&& MacroAnnotations.hasMacroAnnotation(tree.symbol)
then
// Fetch the companion's tree
val companionSym =
if tree.symbol.is(ModuleClass) then tree.symbol.companionClass
else if tree.symbol.is(ModuleVal) then NoSymbol
else tree.symbol.companionModule.moduleClass

val companionTree = MacroAnnotations.findTrackedTree(companionSym)

// Fetch the latest tracked tree (It might have already been processed by its companion)
val latestTree = MacroAnnotations.findTrackedTree(tree.symbol)
.getOrElse(tree)

// Expand and process MacroAnnotations
val (trees, companion) =
MacroAnnotations(self).expandAnnotations(latestTree, companionTree)

// Update the tracked trees
for case tree : MemberDef <- trees do
MacroAnnotations.updateTrackedTree(tree.symbol, tree)
for tree <- companion do
MacroAnnotations.updateTrackedTree(tree.symbol, tree)

// Perform inlining on the expansion of the annotations
val trees1 = trees.map(super.transform)

for case tree: MemberDef <- trees1 do
MacroAnnotations.updateTrackedTree(tree.symbol, tree)

// Find classes added to the top level from a package object
val (topClasses, trees2) =
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
else (Nil, trees1)
if topClasses.nonEmpty then
newTopClasses.getOrElseUpdate(ctx.owner.owner, new mutable.ListBuffer) ++= topClasses
flatTree(trees2)
else
super.transform(tree) match
case tree: MemberDef =>
MacroAnnotations.updateTrackedTree(tree.symbol, tree)
tree
case tree => tree
end transformMemberDef

}

}

object Inlining:
Expand Down
Loading

0 comments on commit ef74f4d

Please sign in to comment.