Skip to content

Commit

Permalink
Add syntax cap[qual] for outer capture roots
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed Sep 1, 2023
1 parent 7913391 commit 9ab1a70
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 54 deletions.
11 changes: 11 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,17 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
case _ =>
tree.tpe.isInstanceOf[ThisType]
}

/** Under capture checking, an extractor for qualified roots `cap[Q]`.
*/
object QualifiedRoot:

def unapply(tree: Apply)(using Context): Option[String] = tree match
case Apply(fn, Literal(lit) :: Nil) if fn.symbol == defn.Caps_capIn =>
Some(lit.value.asInstanceOf[String])
case _ =>
None
end QualifiedRoot
}

trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped] =>
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case Floating
}

/** {x1, ..., xN} T (only relevant under captureChecking) */
/** {x1, ..., xN} T (only relevant under captureChecking)
* Created when parsing function types so that capture set and result type
* is combined in a single node.
*/
case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree

/** A type tree appearing somewhere in the untyped DefDef of a lambda, it will be typed using `tpFun`.
Expand Down Expand Up @@ -512,6 +515,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def captureRoot(using Context): Select =
Select(scalaDot(nme.caps), nme.CAPTURE_ROOT)

def captureRootIn(using Context): Select =
Select(scalaDot(nme.caps), nme.capIn)

def makeRetaining(parent: Tree, refs: List[Tree], annotName: TypeName)(using Context): Annotated =
Annotated(parent, New(scalaAnnotationDot(annotName), List(refs)))

Expand Down
63 changes: 27 additions & 36 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,14 @@ end mapRoots
extension (tree: Tree)

/** Map tree with CaptureRef type to its type, throw IllegalCaptureRef otherwise */
def toCaptureRef(using Context): CaptureRef = tree.tpe match
case ref: CaptureRef => ref
case tpe => throw IllegalCaptureRef(tpe)
def toCaptureRef(using Context): CaptureRef = tree match
case QualifiedRoot(outer) =>
ctx.owner.levelOwnerNamed(outer)
.orElse(defn.captureRoot) // non-existing outer roots are reported in Setup's checkQualifiedRoots
.localRoot.termRef
case _ => tree.tpe match
case ref: CaptureRef => ref
case tpe => throw IllegalCaptureRef(tpe) // if this was compiled from cc syntax, problem should have been reported at Typer

/** Convert a @retains or @retainsByName annotation tree to the capture set it represents.
* For efficience, the result is cached as an Attachment on the tree.
Expand Down Expand Up @@ -266,39 +271,6 @@ extension (tp: Type)
tp.tp1.isAlwaysPure && tp.tp2.isAlwaysPure
case _ =>
false
/*!!!
def capturedLocalRoot(using Context): Symbol =
tp.captureSet.elems.toList
.filter(_.isLocalRootCapability)
.map(_.termSymbol)
.maxByOption(_.ccNestingLevel)
.getOrElse(NoSymbol)
/** Remap roots defined in `cls` to the ... */
def remapRoots(pre: Type, cls: Symbol)(using Context): Type =
if cls.isStaticOwner then tp
else
val from =
if cls.source == ctx.compilationUnit.source then cls.localRoot
else defn.captureRoot
mapRoots(from, capturedLocalRoot)(tp)
def containsRoot(root: Symbol)(using Context): Boolean =
val search = new TypeAccumulator[Boolean]:
def apply(x: Boolean, t: Type): Boolean =
if x then true
else t.dealias match
case t1: TermRef if t1.symbol == root => true
case t1: TypeRef if t1.classSymbol.hasAnnotation(defn.CapabilityAnnot) => true
case t1: MethodType =>
!foldOver(x, t1.paramInfos) && this(x, t1.resType)
case t1 @ AppliedType(tycon, args) if defn.isFunctionSymbol(tycon.typeSymbol) =>
val (inits, last :: Nil) = args.splitAt(args.length - 1): @unchecked
!foldOver(x, inits) && this(x, last)
case t1 => foldOver(x, t1)
search(false, tp)
*/

extension (cls: ClassSymbol)

Expand Down Expand Up @@ -405,6 +377,7 @@ extension (sym: Symbol)
case psyms :: _ => psyms.find(_.info.typeSymbol == defn.Caps_Cap).getOrElse(NoSymbol)
case _ => NoSymbol

/** The local root corresponding to sym's level owner */
def localRoot(using Context): Symbol =
val owner = sym.levelOwner
assert(owner.exists)
Expand All @@ -415,6 +388,24 @@ extension (sym: Symbol)
else newRoot
ccState.localRoots.getOrElseUpdate(owner, lclRoot)

/** The level owner enclosing `sym` which has the given name, or NoSymbol if none exists.
* If name refers to a val that has a closure as rhs, we return the closure as level
* owner.
*/
def levelOwnerNamed(name: String)(using Context): Symbol =
def recur(owner: Symbol, prev: Symbol): Symbol =
if owner.name.toString == name then
if owner.isLevelOwner then owner
else if owner.isTerm && !owner.isOneOf(Method | Module) && prev.exists then prev
else NoSymbol
else if owner == defn.RootClass then
NoSymbol
else
val prev1 = if owner.isAnonymousFunction && owner.isLevelOwner then owner else NoSymbol
recur(owner.owner, prev1)
recur(sym, NoSymbol)
.showing(i"find outer $sym [ $name ] = $result", capt)

def maxNested(other: Symbol)(using Context): Symbol =
if sym.ccNestingLevel < other.ccNestingLevel then other else sym
/* does not work yet, we do mix sets with different levels, for instance in cc-this.scala.
Expand Down
34 changes: 22 additions & 12 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,15 @@ object CheckCaptures:
report.error(em"Singleton type $parent cannot have capture set", parent.srcPos)
case _ =>
for elem <- retainedElems(ann) do
elem.tpe match
case ref: CaptureRef =>
if !ref.isTrackableRef then
report.error(em"$elem cannot be tracked since it is not a parameter or local value", elem.srcPos)
case tpe =>
report.error(em"$elem: $tpe is not a legal element of a capture set", elem.srcPos)
elem match
case QualifiedRoot(outer) =>
// Will be checked by Setup's checkOuterRoots
case _ => elem.tpe match
case ref: CaptureRef =>
if !ref.isTrackableRef then
report.error(em"$elem cannot be tracked since it is not a parameter or local value", elem.srcPos)
case tpe =>
report.error(em"$elem: $tpe is not a legal element of a capture set", elem.srcPos)

/** If `tp` is a capturing type, check that all references it mentions have non-empty
* capture sets. Also: warn about redundant capture annotations.
Expand All @@ -155,7 +158,7 @@ object CheckCaptures:
if ref.captureSetOfInfo.elems.isEmpty then
report.error(em"$ref cannot be tracked since its capture set is empty", pos)
else if parent.captureSet.accountsFor(ref) then
report.warning(em"redundant capture: $parent already accounts for $ref", pos)
report.warning(em"redundant capture: $parent already accounts for $ref in $tp", pos)
case _ =>

/** Warn if `ann`, which is the tree of a @retains annotation, defines some elements that
Expand All @@ -166,11 +169,15 @@ object CheckCaptures:
def warnIfRedundantCaptureSet(ann: Tree, tpt: Tree)(using Context): Unit =
var retained = retainedElems(ann).toArray
for i <- 0 until retained.length do
val ref = retained(i).toCaptureRef
val refTree = retained(i)
val ref = refTree.toCaptureRef
val others = for j <- 0 until retained.length if j != i yield retained(j).toCaptureRef
val remaining = CaptureSet(others*)
if remaining.accountsFor(ref) then
val srcTree = if ann.span.exists then ann else tpt
val srcTree =
if refTree.span.exists then refTree
else if ann.span.exists then ann
else tpt
report.warning(em"redundant capture: $remaining already accounts for $ref", srcTree.srcPos)

/** Attachment key for bodies of closures, provided they are values */
Expand Down Expand Up @@ -1192,9 +1199,12 @@ class CheckCaptures extends Recheck, SymTransformer:
def postCheck(unit: tpd.Tree)(using Context): Unit =
val checker = new TreeTraverser:
def traverse(tree: Tree)(using Context): Unit =
traverseChildren(tree)
val lctx = tree match
case _: DefTree | _: TypeDef if tree.symbol.exists => ctx.withOwner(tree.symbol)
case _ => ctx
traverseChildren(tree)(using lctx)
check(tree)
def check(tree: Tree) = tree match
def check(tree: Tree)(using Context) = tree match
case _: InferredTypeTree =>
case tree: TypeTree if !tree.span.isZeroExtent =>
tree.knownType.foreachPart { tp =>
Expand Down Expand Up @@ -1253,7 +1263,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case _ =>
end check
end checker
checker.traverse(unit)
checker.traverse(unit)(using ctx.withOwner(defn.RootClass))
if !ctx.reporter.errorsReported then
// We dont report errors here if previous errors were reported, because other
// errors often result in bad applied types, but flagging these bad types gives
Expand Down
10 changes: 8 additions & 2 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ extends tpd.TreeTraverser:
then CapturingType(tp, CaptureSet.universal, boxed = false)
else tp

private def checkQualifiedRoots(tree: Tree)(using Context): Unit =
for case elem @ QualifiedRoot(outer) <- retainedElems(tree) do
if !ctx.owner.levelOwnerNamed(outer).exists then
report.error(em"`$outer` does not name an outer definition that represents a capture level", elem.srcPos)

private def expandAliases(using Context) = new TypeMap with FollowAliases:
override def toString = "expand aliases"
def apply(t: Type) =
Expand All @@ -226,12 +231,13 @@ extends tpd.TreeTraverser:
if t2 ne t then return t2
t match
case t @ AnnotatedType(t1, ann) =>
val t2 =
checkQualifiedRoots(ann.tree)
val t3 =
if ann.symbol == defn.RetainsAnnot && isCapabilityClassRef(t1) then t1
else this(t1)
// Don't map capture sets, since that would implicitly normalize sets that
// are not well-formed.
t.derivedAnnotatedType(t2, ann)
t.derivedAnnotatedType(t3, ann)
case _ =>
mapOverFollowingAliases(t)

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,7 @@ class Definitions {
@tu lazy val CapsModule: Symbol = requiredModule("scala.caps")
@tu lazy val captureRoot: TermSymbol = CapsModule.requiredValue("cap")
@tu lazy val Caps_Cap: TypeSymbol = CapsModule.requiredType("Cap")
@tu lazy val Caps_capIn: TermSymbol = CapsModule.requiredMethod("capIn")
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
@tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure")
@tu lazy val Caps_unsafeBox: Symbol = CapsUnsafeModule.requiredMethod("unsafeBox")
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ object StdNames {
val bytes: N = "bytes"
val canEqual_ : N = "canEqual"
val canEqualAny : N = "canEqualAny"
val capIn: N = "capIn"
val caps: N = "caps"
val captureChecking: N = "captureChecking"
val checkInitialized: N = "checkInitialized"
Expand Down
16 changes: 13 additions & 3 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1423,13 +1423,23 @@ object Parsers {
case _ => None
}

/** CaptureRef ::= ident | `this`
/** CaptureRef ::= ident | `this` | `cap` [`[` ident `]`]
*/
def captureRef(): Tree =
if in.token == THIS then simpleRef()
else termIdent() match
case Ident(nme.CAPTURE_ROOT) => captureRoot
case id => id
case id @ Ident(nme.CAPTURE_ROOT) =>
if in.token == LBRACKET then
val ref = atSpan(id.span.start)(captureRootIn)
val qual =
inBrackets:
atSpan(in.offset):
Literal(Constant(ident().toString))
atSpan(id.span.start)(Apply(ref, qual :: Nil))
else
atSpan(id.span.start)(captureRoot)
case id =>
id

/** CaptureSet ::= `{` CaptureRef {`,` CaptureRef} `}` -- under captureChecking
*/
Expand Down
2 changes: 2 additions & 0 deletions library/src/scala/caps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import annotation.experimental

given Cap = cap

def capIn(scope: String): Cap = ()

object unsafe:

extension [T](x: T)
Expand Down
7 changes: 7 additions & 0 deletions tests/neg-custom-args/captures/localcaps.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class C:
def x: C^{cap[d]} = ??? // error

def y: C^{cap[C]} = ??? // ok
private val z = (x: Int) => (c: C^{cap[z]}) => x // ok

private val z2 = identity((x: Int) => (c: C^{cap[z2]}) => x) // error
15 changes: 15 additions & 0 deletions tests/pos-custom-args/captures/pairs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,18 @@ object Monomorphic:
val x1c: Cap ->{c} Unit = x1
val y1 = p.snd
val y1c: Cap ->{d} Unit = y1

object Monomorphic2:

class Pair(x: Cap => Unit, y: Cap => Unit):
def fst: Cap^{cap[Pair]} ->{x} Unit = x
def snd: Cap^{cap[Pair]} ->{y} Unit = y

def test(c: Cap, d: Cap) =
def f(x: Cap): Unit = if c == x then ()
def g(x: Cap): Unit = if d == x then ()
val p = Pair(f, g)
val x1 = p.fst
val x1c: Cap ->{c} Unit = x1
val y1 = p.snd
val y1c: Cap ->{d} Unit = y1

0 comments on commit 9ab1a70

Please sign in to comment.