diff --git a/compiler/src/dotty/tools/dotc/cc/SepCheck.scala b/compiler/src/dotty/tools/dotc/cc/SepCheck.scala index be71fe82dc72..6989ef21f081 100644 --- a/compiler/src/dotty/tools/dotc/cc/SepCheck.scala +++ b/compiler/src/dotty/tools/dotc/cc/SepCheck.scala @@ -457,14 +457,16 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser: * Also check separation via checkType within individual arguments widened to their * formal paramater types. * - * @param fn the applied function - * @param args the flattened argument lists - * @param app the entire application tree - * @param deps cross argument dependencies: maps argument trees to - * those other arguments that where mentioned by coorresponding - * formal parameters. + * @param fn the applied function + * @param args the flattened argument lists + * @param app the entire application tree + * @param deps cross argument dependencies: maps argument trees to + * those other arguments that where mentioned by coorresponding + * formal parameters. + * @param resultPeaks peaks in the result type that could interfere with the + * hidden sets of formal parameters */ - private def checkApply(fn: Tree, args: List[Tree], app: Tree, deps: collection.Map[Tree, List[Tree]])(using Context): Unit = + private def checkApply(fn: Tree, args: List[Tree], app: Tree, deps: collection.Map[Tree, List[Tree]], resultPeaks: Refs)(using Context): Unit = val (qual, fnCaptures) = methPart(fn) match case Select(qual, _) => (qual, qual.nuType.captureSet) case _ => (fn, CaptureSet.empty) @@ -475,6 +477,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser: i"""check separate $fn($args), fnCaptures = $fnCaptures, | formalCaptures = ${args.map(arg => CaptureSet(formalCaptures(arg)))}, | actualCaptures = ${args.map(arg => CaptureSet(captures(arg)))}, + | resultPeaks = ${resultPeaks}, | deps = ${deps.toList}""") val parts = qual :: args var reported: SimpleIdentitySet[Tree] = SimpleIdentitySet.empty @@ -519,26 +522,10 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser: currentPeaks.hidden ++ argPeaks.hidden) end for - def collectRefs(args: List[Type], res: Type) = - args.foldLeft(argCaptures(res)): (refs, arg) => - refs ++ arg.deepCaptureSet.elems - - /** The deep capture sets of all parameters of this type (if it is a function type) */ - def argCaptures(tpe: Type): Refs = tpe match - case defn.FunctionOf(args, resultType, isContextual) => - collectRefs(args, resultType) - case defn.RefinedFunctionOf(mt) => - collectRefs(mt.paramInfos, mt.resType) - case CapturingType(parent, _) => - argCaptures(parent) - case _ => - emptyRefs - - if !deps(app).isEmpty then - lazy val appPeaks = argCaptures(app.nuType).peaks + if !resultPeaks.isEmpty then lazy val partPeaks = partsWithPeaks.toMap - for arg <- deps(app) do - if arg.needsSepCheck && !partPeaks(arg).hidden.sharedWith(appPeaks).isEmpty then + for arg <- args do + if arg.needsSepCheck && !partPeaks(arg).hidden.sharedWith(resultPeaks).isEmpty then sepApplyError(fn, parts, arg, app) end checkApply @@ -816,10 +803,15 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser: * then the dependencies of an application `f(a, b, c)` of type C^{y} is the map * * [ b -> [a] - * , c -> [a, b] - * , f(a, b, c) -> [b]] + * , c -> [a, b] ] + * + * It also returns the interfering peaks of the result of the application. They are the + * peaks of argument captures and deep captures of the result function type, minus the + * those dependent on parameters. For instance, + * if `f` has the type (x: A, y: B, c: C) -> (op: () ->{b} Unit) -> List[() ->{x, y, a} Unit], its interfering + * peaks will be the peaks of `a` and `b`. */ - private def dependencies(fn: Tree, argss: List[List[Tree]], app: Tree)(using Context): collection.Map[Tree, List[Tree]] = + private def dependencies(fn: Tree, argss: List[List[Tree]], app: Tree)(using Context): (collection.Map[Tree, List[Tree]], Refs) = def isFunApply(sym: Symbol) = sym.name == nme.apply && defn.isFunctionClass(sym.owner) val mtpe = @@ -831,23 +823,47 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser: val argMap = mtpsWithArgs.toMap val deps = mutable.HashMap[Tree, List[Tree]]().withDefaultValue(Nil) + def argOfDep(dep: Capability): Option[Tree] = + dep.stripReach match + case dep: TermParamRef => + Some(argMap(dep.binder)(dep.paramNum)) + case dep: ThisType if dep.cls == fn.symbol.owner => + val Select(qual, _) = fn: @unchecked // TODO can we use fn instead? + Some(qual) + case _ => + None + def recordDeps(formal: Type, actual: Tree) = - for dep <- formal.captureSet.elems.toList do - val referred = dep.stripReach match - case dep: TermParamRef => - argMap(dep.binder)(dep.paramNum) :: Nil - case dep: ThisType if dep.cls == fn.symbol.owner => - val Select(qual, _) = fn: @unchecked // TODO can we use fn instead? - qual :: Nil - case _ => - Nil + def captures = formal.captureSet + for dep <- captures.elems.toList do + val referred = argOfDep(dep) deps(actual) ++= referred + inline def isLocalRef(x: Capability): Boolean = x.isInstanceOf[TermParamRef] + + def resultArgCaptures(tpe: Type): Refs = + def collectRefs(args: List[Type], res: Type) = + args.foldLeft(resultArgCaptures(res)): (refs, arg) => + refs ++ arg.captureSet.elems + tpe match + case defn.FunctionOf(args, resultType, isContextual) => + collectRefs(args, resultType) + case defn.RefinedFunctionOf(mt) => + collectRefs(mt.paramInfos, mt.resType) + case CapturingType(parent, refs) => + resultArgCaptures(parent) ++ tpe.boxedCaptureSet.elems + case _ => + emptyRefs + for (mt, args) <- mtpsWithArgs; (formal, arg) <- mt.paramInfos.zip(args) do recordDeps(formal, arg) - recordDeps(mtpe.finalResultType, app) + + val resultType = mtpe.finalResultType + val resultCaptures = + (resultArgCaptures(resultType) ++ resultType.deepCaptureSet.elems).filter(!isLocalRef(_)) + val resultPeaks = resultCaptures.peaks capt.println(i"deps for $app = ${deps.toList}") - deps + (deps, resultPeaks) /** Decompose an application into a function prefix and a list of argument lists. @@ -860,7 +876,8 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser: case TypeApply(fn, args) => recur(fn, argss) // skip type arguments case _ => if argss.nestedExists(_.needsSepCheck) then - checkApply(tree, argss.flatten, app, dependencies(tree, argss, app)) + val (deps, resultPeaks) = dependencies(tree, argss, app) + checkApply(tree, argss.flatten, app, deps, resultPeaks) recur(app, Nil) /** Is `tree` an application of `caps.unsafe.unsafeAssumeSeparate`? */ diff --git a/tests/neg-custom-args/captures/i23726.check b/tests/neg-custom-args/captures/i23726.check new file mode 100644 index 000000000000..8c8ac94a61e0 --- /dev/null +++ b/tests/neg-custom-args/captures/i23726.check @@ -0,0 +1,51 @@ +-- Error: tests/neg-custom-args/captures/i23726.scala:10:5 ------------------------------------------------------------- +10 | f1(a) // error, as expected + | ^ + |Separation failure: argument of type (a : Ref^) + |to a function of type (x: Ref^) -> List[() ->{a, x} Unit] + |corresponds to capture-polymorphic formal parameter x of type Ref^² + |and hides capabilities {a}. + |Some of these overlap with the captures of the function result with type List[() ->{a} Unit]. + | + | Hidden set of current argument : {a} + | Hidden footprint of current argument : {a} + | Capture set of function result : {a} + | Footprint set of function result : {a} + | The two sets overlap at : {a} + | + |where: ^ refers to a fresh root capability classified as Mutable created in value a when constructing mutable Ref + | ^² refers to a fresh root capability classified as Mutable created in method test1 when checking argument to parameter x of method apply +-- Error: tests/neg-custom-args/captures/i23726.scala:15:5 ------------------------------------------------------------- +15 | f3(b) // error + | ^ + |Separation failure: argument of type (b : Ref^) + |to a function of type (x: Ref^) -> (op: () ->{b} Unit) -> List[() ->{op} Unit] + |corresponds to capture-polymorphic formal parameter x of type Ref^² + |and hides capabilities {b}. + |Some of these overlap with the captures of the function result with type (op: () ->{b} Unit) -> List[() ->{op} Unit]. + | + | Hidden set of current argument : {b} + | Hidden footprint of current argument : {b} + | Capture set of function result : {op} + | Footprint set of function result : {op, b} + | The two sets overlap at : {b} + | + |where: ^ refers to a fresh root capability classified as Mutable created in value b when constructing mutable Ref + | ^² refers to a fresh root capability classified as Mutable created in method test1 when checking argument to parameter x of method apply +-- Error: tests/neg-custom-args/captures/i23726.scala:23:5 ------------------------------------------------------------- +23 | f7(a) // error + | ^ + |Separation failure: argument of type (a : Ref^) + |to a function of type (x: Ref^) ->{a, b} (y: List[Ref^{a, b}]) ->{a, b} Unit + |corresponds to capture-polymorphic formal parameter x of type Ref^² + |and hides capabilities {a}. + |Some of these overlap with the captures of the function prefix. + | + | Hidden set of current argument : {a} + | Hidden footprint of current argument : {a} + | Capture set of function prefix : {f7*} + | Footprint set of function prefix : {f7*, a, b} + | The two sets overlap at : {a} + | + |where: ^ refers to a fresh root capability classified as Mutable created in value a when constructing mutable Ref + | ^² refers to a fresh root capability classified as Mutable created in method test1 when checking argument to parameter x of method apply diff --git a/tests/neg-custom-args/captures/i23726.scala b/tests/neg-custom-args/captures/i23726.scala new file mode 100644 index 000000000000..fc833ef29583 --- /dev/null +++ b/tests/neg-custom-args/captures/i23726.scala @@ -0,0 +1,23 @@ +import language.experimental.captureChecking +import language.experimental.separationChecking +import caps.* +class Ref extends Mutable +def swap(a: Ref^, b: Ref^): Unit = () +def test1(): Unit = + val a = Ref() + val b = Ref() + val f1: (x: Ref^) -> List[() ->{a,x} Unit] = ??? + f1(a) // error, as expected + val f2: (x: Ref^) -> List[() ->{x} Unit] = ??? + f2(a) // ok, as expected + val f3: (x: Ref^) -> (op: () ->{b} Unit) -> List[() ->{op} Unit] = ??? + f3(a) // ok + f3(b) // error + val f4: (x: Ref^) -> (y: Ref^{x}) ->{x} Unit = ??? + f4(a) // ok + val f5: (x: Ref^) -> (y: List[Ref^{a}]) ->{} Unit = ??? + f5(a) // ok + val f6: (x: Ref^) -> (y: List[Ref^{a, b}]) ->{} Unit = ??? + f6(b) // ok + val f7: (x: Ref^) ->{a, b} (y: List[Ref^{a, b}]) ->{a, b} Unit = ??? + f7(a) // error