-
Notifications
You must be signed in to change notification settings - Fork 1k
/
TreeMapWithImplicits.scala
129 lines (117 loc) · 3.99 KB
/
TreeMapWithImplicits.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package dotty.tools.dotc
package ast
import Trees._
import core.Contexts._
import core.ContextOps.enter
import core.Flags._
import core.Symbols._
import core.TypeError
import scala.annotation.tailrec
/** A TreeMap that maintains the necessary infrastructure to support
 *  contextual implicit searches (type-scope implicits are supported anyway).
 *
 *  This includes implicits defined in scope as well as imported implicits:
 *  given/implicit definitions introduced by blocks, method parameters and
 *  patterns are entered into fresh nested scopes, and an import extends the
 *  context of the statements that follow it.
 */
class TreeMapWithImplicits extends tpd.TreeMap {
  import tpd._

  /** Transforms the self definition of a template; only its type tree is mapped. */
  def transformSelf(vd: ValDef)(using Context): ValDef =
    cpy.ValDef(vd)(tpt = transform(vd.tpt))

  /** Transform statements, while maintaining import contexts and expression contexts
   *  in the same way as Typer does. The code addresses additional concerns:
   *   - be tail-recursive where possible
   *   - don't re-allocate trees where nothing has changed
   *
   *  @param stats     the statements to transform
   *  @param exprOwner owner of the expression context used for non-definition statements
   *  @return the transformed statements; `stats` itself if no statement changed
   */
  def transformStats(stats: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = {
    /* Walks over the statements that `transform` leaves unchanged (tail-recursively). */
    @tailrec def traverse(curStats: List[Tree])(using Context): List[Tree] = {
      /* Called once the head of `curStats` was transformed into `changed`:
       * re-conses the unchanged statements of `stats` that precede `curStats`,
       * splices in `changed` (flattening a Thicket) and transforms `rest`
       * with a recursive `transformStats` call in the context given to `recur`.
       */
      def recur(stats: List[Tree], changed: Tree, rest: List[Tree])(using Context): List[Tree] =
        if (stats eq curStats) {
          val rest1 = transformStats(rest, exprOwner)
          changed match {
            case Thicket(trees) => trees ::: rest1
            case tree => tree :: rest1
          }
        }
        else stats.head :: recur(stats.tail, changed, rest)

      curStats match {
        case stat :: rest =>
          // Definitions are transformed in the current context; any other
          // statement in an expression context owned by `exprOwner`.
          val statCtx = stat match {
            case _: DefTree => ctx
            case _ => ctx.exprContext(stat, exprOwner)
          }
          // An import is visible to the statements that follow it.
          val restCtx = stat match {
            case stat: Import => ctx.importContext(stat, stat.symbol)
            case _ => ctx
          }
          val stat1 = transform(stat)(using statCtx)
          if (stat1 ne stat) recur(stats, stat1, rest)(using restCtx)
          else traverse(rest)(using restCtx)
        case Nil =>
          // No statement changed: share the original list.
          stats
      }
    }
    traverse(stats)
  }

  /** True if `tree` is a definition whose symbol carries the Given or Implicit flag. */
  private def definesGivenOrImplicit(tree: Tree)(using Context): Boolean = tree match {
    case d: DefTree => d.symbol.isOneOf(GivenOrImplicit)
    case _ => false
  }

  /** A fresh context with a new scope into which the given/implicit definitions
   *  among `defs` are entered (only the trees in `defs`, not their children).
   */
  private def nestedScopeCtx(defs: List[Tree])(using Context): Context = {
    val nestedCtx = ctx.fresh.setNewScope
    defs.foreach { d => if (definesGivenOrImplicit(d)) nestedCtx.enter(d.symbol) }
    nestedCtx
  }

  /** A fresh context with a new scope into which every given/implicit definition
   *  occurring anywhere inside `pattern` (at any depth) is entered.
   */
  private def patternScopeCtx(pattern: Tree)(using Context): Context = {
    val nestedCtx = ctx.fresh.setNewScope
    new TreeTraverser {
      def traverse(tree: Tree)(using Context): Unit = {
        if (definesGivenOrImplicit(tree)) nestedCtx.enter(tree.symbol)
        traverseChildren(tree)
      }
    }.traverse(pattern)
    nestedCtx
  }

  /** Transforms `tree`, switching the owner to the tree's symbol for definitions
   *  and entering given/implicit definitions of blocks, method parameters and
   *  patterns into nested scopes. A `TypeError` raised during the transformation
   *  is reported at the tree's position and the original tree is returned.
   */
  override def transform(tree: Tree)(using Context): Tree = {
    // Context owned by the tree's symbol, when the tree is typed and has one.
    def localCtx =
      if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx

    try tree match {
      case tree: Block =>
        // Given/implicit statements of the block are in scope for the whole block.
        super.transform(tree)(using nestedScopeCtx(tree.stats))
      case tree: DefDef =>
        given Context = localCtx
        cpy.DefDef(tree)(
          tree.name,
          transformSub(tree.tparams),
          tree.vparamss.mapConserve(transformSub(_)),
          transform(tree.tpt),
          // Given/implicit parameters are in scope for the right-hand side.
          transform(tree.rhs)(using nestedScopeCtx(tree.vparamss.flatten)))
      case EmptyValDef =>
        tree
      case _: PackageDef | _: MemberDef =>
        super.transform(tree)(using localCtx)
      case impl @ Template(constr, parents, self, _) =>
        cpy.Template(tree)(
          transformSub(constr),
          transform(parents)(using ctx.superCallContext),
          // NOTE(review): this slot (presumably the derived-clause list) is reset
          // to Nil rather than transformed — confirm this is intended.
          Nil,
          transformSelf(self),
          transformStats(impl.body, tree.symbol))
      case tree: CaseDef =>
        // The pattern itself is transformed in the enclosing context; given/implicit
        // definitions bound by the pattern are in scope for the guard and the body.
        val patCtx = patternScopeCtx(tree.pat)
        cpy.CaseDef(tree)(
          transform(tree.pat),
          transform(tree.guard)(using patCtx),
          transform(tree.body)(using patCtx))
      case _ =>
        super.transform(tree)
    }
    catch {
      case ex: TypeError =>
        report.error(ex, tree.srcPos)
        tree
    }
  }
}