From 9ab1a708ddacbaafda22bf8cca2f3e48670b8103 Mon Sep 17 00:00:00 2001 From: odersky Date: Fri, 1 Sep 2023 14:58:55 +0200 Subject: [PATCH] Add syntax `cap[qual]` for outer capture roots --- .../src/dotty/tools/dotc/ast/TreeInfo.scala | 11 ++++ compiler/src/dotty/tools/dotc/ast/untpd.scala | 8 ++- .../src/dotty/tools/dotc/cc/CaptureOps.scala | 63 ++++++++----------- .../dotty/tools/dotc/cc/CheckCaptures.scala | 34 ++++++---- compiler/src/dotty/tools/dotc/cc/Setup.scala | 10 ++- .../dotty/tools/dotc/core/Definitions.scala | 1 + .../src/dotty/tools/dotc/core/StdNames.scala | 1 + .../dotty/tools/dotc/parsing/Parsers.scala | 16 ++++- library/src/scala/caps.scala | 2 + .../neg-custom-args/captures/localcaps.scala | 7 +++ tests/pos-custom-args/captures/pairs.scala | 15 +++++ 11 files changed, 114 insertions(+), 54 deletions(-) create mode 100644 tests/neg-custom-args/captures/localcaps.scala diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index e60d6e86754c..c569fe047b66 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -376,6 +376,17 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] => case _ => tree.tpe.isInstanceOf[ThisType] } + + /** Under capture checking, an extractor for qualified roots `cap[Q]`. + */ + object QualifiedRoot: + + def unapply(tree: Apply)(using Context): Option[String] = tree match + case Apply(fn, Literal(lit) :: Nil) if fn.symbol == defn.Caps_capIn => + Some(lit.value.asInstanceOf[String]) + case _ => + None + end QualifiedRoot } trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped] => diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 8cc0750de53c..e7d38da854a4 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -149,7 +149,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case Floating } - /** {x1, ..., xN} T (only relevant under captureChecking) */ + /** {x1, ..., xN} T (only relevant under captureChecking) + * Created when parsing function types so that capture set and result type + * is combined in a single node. + */ case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree /** A type tree appearing somewhere in the untyped DefDef of a lambda, it will be typed using `tpFun`. @@ -512,6 +515,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { def captureRoot(using Context): Select = Select(scalaDot(nme.caps), nme.CAPTURE_ROOT) + def captureRootIn(using Context): Select = + Select(scalaDot(nme.caps), nme.capIn) + def makeRetaining(parent: Tree, refs: List[Tree], annotName: TypeName)(using Context): Annotated = Annotated(parent, New(scalaAnnotationDot(annotName), List(refs))) diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index 2a860e42cfe0..7f7468675aae 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -103,9 +103,14 @@ end mapRoots extension (tree: Tree) /** Map tree with CaptureRef type to its type, throw IllegalCaptureRef otherwise */ - def toCaptureRef(using Context): CaptureRef = tree.tpe match - case ref: CaptureRef => ref - case tpe => throw IllegalCaptureRef(tpe) + def toCaptureRef(using Context): CaptureRef = tree match + case QualifiedRoot(outer) => + ctx.owner.levelOwnerNamed(outer) + .orElse(defn.captureRoot) // non-existing outer roots are reported in Setup's checkQualifiedRoots + .localRoot.termRef + case _ => tree.tpe match + case ref: CaptureRef => ref + case tpe => throw IllegalCaptureRef(tpe) // if this was compiled from cc syntax, problem should have been reported at Typer /** Convert a @retains or @retainsByName annotation tree to the capture set it represents. * For efficience, the result is cached as an Attachment on the tree. @@ -266,39 +271,6 @@ extension (tp: Type) tp.tp1.isAlwaysPure && tp.tp2.isAlwaysPure case _ => false -/*!!! - def capturedLocalRoot(using Context): Symbol = - tp.captureSet.elems.toList - .filter(_.isLocalRootCapability) - .map(_.termSymbol) - .maxByOption(_.ccNestingLevel) - .getOrElse(NoSymbol) - - /** Remap roots defined in `cls` to the ... */ - def remapRoots(pre: Type, cls: Symbol)(using Context): Type = - if cls.isStaticOwner then tp - else - val from = - if cls.source == ctx.compilationUnit.source then cls.localRoot - else defn.captureRoot - mapRoots(from, capturedLocalRoot)(tp) - - - def containsRoot(root: Symbol)(using Context): Boolean = - val search = new TypeAccumulator[Boolean]: - def apply(x: Boolean, t: Type): Boolean = - if x then true - else t.dealias match - case t1: TermRef if t1.symbol == root => true - case t1: TypeRef if t1.classSymbol.hasAnnotation(defn.CapabilityAnnot) => true - case t1: MethodType => - !foldOver(x, t1.paramInfos) && this(x, t1.resType) - case t1 @ AppliedType(tycon, args) if defn.isFunctionSymbol(tycon.typeSymbol) => - val (inits, last :: Nil) = args.splitAt(args.length - 1): @unchecked - !foldOver(x, inits) && this(x, last) - case t1 => foldOver(x, t1) - search(false, tp) -*/ extension (cls: ClassSymbol) @@ -405,6 +377,7 @@ extension (sym: Symbol) case psyms :: _ => psyms.find(_.info.typeSymbol == defn.Caps_Cap).getOrElse(NoSymbol) case _ => NoSymbol + /** The local root corresponding to sym's level owner */ def localRoot(using Context): Symbol = val owner = sym.levelOwner assert(owner.exists) @@ -415,6 +388,24 @@ extension (sym: Symbol) else newRoot ccState.localRoots.getOrElseUpdate(owner, lclRoot) + /** The level owner enclosing `sym` which has the given name, or NoSymbol if none exists. + * If name refers to a val that has a closure as rhs, we return the closure as level + * owner. + */ + def levelOwnerNamed(name: String)(using Context): Symbol = + def recur(owner: Symbol, prev: Symbol): Symbol = + if owner.name.toString == name then + if owner.isLevelOwner then owner + else if owner.isTerm && !owner.isOneOf(Method | Module) && prev.exists then prev + else NoSymbol + else if owner == defn.RootClass then + NoSymbol + else + val prev1 = if owner.isAnonymousFunction && owner.isLevelOwner then owner else NoSymbol + recur(owner.owner, prev1) + recur(sym, NoSymbol) + .showing(i"find outer $sym [ $name ] = $result", capt) + def maxNested(other: Symbol)(using Context): Symbol = if sym.ccNestingLevel < other.ccNestingLevel then other else sym /* does not work yet, we do mix sets with different levels, for instance in cc-this.scala. diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index a181b521ab52..0afd9137e6dd 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -138,12 +138,15 @@ object CheckCaptures: report.error(em"Singleton type $parent cannot have capture set", parent.srcPos) case _ => for elem <- retainedElems(ann) do - elem.tpe match - case ref: CaptureRef => - if !ref.isTrackableRef then - report.error(em"$elem cannot be tracked since it is not a parameter or local value", elem.srcPos) - case tpe => - report.error(em"$elem: $tpe is not a legal element of a capture set", elem.srcPos) + elem match + case QualifiedRoot(outer) => + // Will be checked by Setup's checkOuterRoots + case _ => elem.tpe match + case ref: CaptureRef => + if !ref.isTrackableRef then + report.error(em"$elem cannot be tracked since it is not a parameter or local value", elem.srcPos) + case tpe => + report.error(em"$elem: $tpe is not a legal element of a capture set", elem.srcPos) /** If `tp` is a capturing type, check that all references it mentions have non-empty * capture sets. Also: warn about redundant capture annotations. @@ -155,7 +158,7 @@ object CheckCaptures: if ref.captureSetOfInfo.elems.isEmpty then report.error(em"$ref cannot be tracked since its capture set is empty", pos) else if parent.captureSet.accountsFor(ref) then - report.warning(em"redundant capture: $parent already accounts for $ref", pos) + report.warning(em"redundant capture: $parent already accounts for $ref in $tp", pos) case _ => /** Warn if `ann`, which is the tree of a @retains annotation, defines some elements that @@ -166,11 +169,15 @@ object CheckCaptures: def warnIfRedundantCaptureSet(ann: Tree, tpt: Tree)(using Context): Unit = var retained = retainedElems(ann).toArray for i <- 0 until retained.length do - val ref = retained(i).toCaptureRef + val refTree = retained(i) + val ref = refTree.toCaptureRef val others = for j <- 0 until retained.length if j != i yield retained(j).toCaptureRef val remaining = CaptureSet(others*) if remaining.accountsFor(ref) then - val srcTree = if ann.span.exists then ann else tpt + val srcTree = + if refTree.span.exists then refTree + else if ann.span.exists then ann + else tpt report.warning(em"redundant capture: $remaining already accounts for $ref", srcTree.srcPos) /** Attachment key for bodies of closures, provided they are values */ @@ -1192,9 +1199,12 @@ class CheckCaptures extends Recheck, SymTransformer: def postCheck(unit: tpd.Tree)(using Context): Unit = val checker = new TreeTraverser: def traverse(tree: Tree)(using Context): Unit = - traverseChildren(tree) + val lctx = tree match + case _: DefTree | _: TypeDef if tree.symbol.exists => ctx.withOwner(tree.symbol) + case _ => ctx + traverseChildren(tree)(using lctx) check(tree) - def check(tree: Tree) = tree match + def check(tree: Tree)(using Context) = tree match case _: InferredTypeTree => case tree: TypeTree if !tree.span.isZeroExtent => tree.knownType.foreachPart { tp => @@ -1253,7 +1263,7 @@ class CheckCaptures extends Recheck, SymTransformer: case _ => end check end checker - checker.traverse(unit) + checker.traverse(unit)(using ctx.withOwner(defn.RootClass)) if !ctx.reporter.errorsReported then // We dont report errors here if previous errors were reported, because other // errors often result in bad applied types, but flagging these bad types gives diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index dc32be5e8e53..3e9a9fe1274e 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -217,6 +217,11 @@ extends tpd.TreeTraverser: then CapturingType(tp, CaptureSet.universal, boxed = false) else tp + private def checkQualifiedRoots(tree: Tree)(using Context): Unit = + for case elem @ QualifiedRoot(outer) <- retainedElems(tree) do + if !ctx.owner.levelOwnerNamed(outer).exists then + report.error(em"`$outer` does not name an outer definition that represents a capture level", elem.srcPos) + private def expandAliases(using Context) = new TypeMap with FollowAliases: override def toString = "expand aliases" def apply(t: Type) = @@ -226,12 +231,13 @@ extends tpd.TreeTraverser: if t2 ne t then return t2 t match case t @ AnnotatedType(t1, ann) => - val t2 = + checkQualifiedRoots(ann.tree) + val t3 = if ann.symbol == defn.RetainsAnnot && isCapabilityClassRef(t1) then t1 else this(t1) // Don't map capture sets, since that would implicitly normalize sets that // are not well-formed. - t.derivedAnnotatedType(t2, ann) + t.derivedAnnotatedType(t3, ann) case _ => mapOverFollowingAliases(t) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 1813e77aa4ee..4a9d4162107b 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -972,6 +972,7 @@ class Definitions { @tu lazy val CapsModule: Symbol = requiredModule("scala.caps") @tu lazy val captureRoot: TermSymbol = CapsModule.requiredValue("cap") @tu lazy val Caps_Cap: TypeSymbol = CapsModule.requiredType("Cap") + @tu lazy val Caps_capIn: TermSymbol = CapsModule.requiredMethod("capIn") @tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe") @tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure") @tu lazy val Caps_unsafeBox: Symbol = CapsUnsafeModule.requiredMethod("unsafeBox") diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 95c7c2cb2cd9..4fc7ea4185d8 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -434,6 +434,7 @@ object StdNames { val bytes: N = "bytes" val canEqual_ : N = "canEqual" val canEqualAny : N = "canEqualAny" + val capIn: N = "capIn" val caps: N = "caps" val captureChecking: N = "captureChecking" val checkInitialized: N = "checkInitialized" diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 93e858be904d..bd5159a60931 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1423,13 +1423,23 @@ object Parsers { case _ => None } - /** CaptureRef ::= ident | `this` + /** CaptureRef ::= ident | `this` | `cap` [`[` ident `]`] */ def captureRef(): Tree = if in.token == THIS then simpleRef() else termIdent() match - case Ident(nme.CAPTURE_ROOT) => captureRoot - case id => id + case id @ Ident(nme.CAPTURE_ROOT) => + if in.token == LBRACKET then + val ref = atSpan(id.span.start)(captureRootIn) + val qual = + inBrackets: + atSpan(in.offset): + Literal(Constant(ident().toString)) + atSpan(id.span.start)(Apply(ref, qual :: Nil)) + else + atSpan(id.span.start)(captureRoot) + case id => + id /** CaptureSet ::= `{` CaptureRef {`,` CaptureRef} `}` -- under captureChecking */ diff --git a/library/src/scala/caps.scala b/library/src/scala/caps.scala index a23de5674476..2db66e8c540d 100644 --- a/library/src/scala/caps.scala +++ b/library/src/scala/caps.scala @@ -15,6 +15,8 @@ import annotation.experimental given Cap = cap + def capIn(scope: String): Cap = () + object unsafe: extension [T](x: T) diff --git a/tests/neg-custom-args/captures/localcaps.scala b/tests/neg-custom-args/captures/localcaps.scala new file mode 100644 index 000000000000..50cbe8e0f8f9 --- /dev/null +++ b/tests/neg-custom-args/captures/localcaps.scala @@ -0,0 +1,7 @@ +class C: + def x: C^{cap[d]} = ??? // error + + def y: C^{cap[C]} = ??? // ok + private val z = (x: Int) => (c: C^{cap[z]}) => x // ok + + private val z2 = identity((x: Int) => (c: C^{cap[z2]}) => x) // error diff --git a/tests/pos-custom-args/captures/pairs.scala b/tests/pos-custom-args/captures/pairs.scala index 43488e2dde54..b78c10d30ef2 100644 --- a/tests/pos-custom-args/captures/pairs.scala +++ b/tests/pos-custom-args/captures/pairs.scala @@ -31,3 +31,18 @@ object Monomorphic: val x1c: Cap ->{c} Unit = x1 val y1 = p.snd val y1c: Cap ->{d} Unit = y1 + +object Monomorphic2: + + class Pair(x: Cap => Unit, y: Cap => Unit): + def fst: Cap^{cap[Pair]} ->{x} Unit = x + def snd: Cap^{cap[Pair]} ->{y} Unit = y + + def test(c: Cap, d: Cap) = + def f(x: Cap): Unit = if c == x then () + def g(x: Cap): Unit = if d == x then () + val p = Pair(f, g) + val x1 = p.fst + val x1c: Cap ->{c} Unit = x1 + val y1 = p.snd + val y1c: Cap ->{d} Unit = y1