diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index a9e4ce42087d..a0d874007728 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -236,6 +236,22 @@ extension (tp: Type) * (2) all covariant occurrences of cap replaced by `x*`, provided there * are no occurrences in `T` at other variances. (1) is standard, whereas * (2) is new. + * + * For (2), multiple-flipped covariant occurrences of cap won't be replaced. + * In other words, + * + * - For xs: List[File^] ==> List[File^{xs*}], the cap is replaced; + * - while f: [R] -> (op: File^ => R) -> R remains unchanged. + * + * Without this restriction, the signature of functions like withFile: + * + * (path: String) -> [R] -> (op: File^ => R) -> R + * + * could be refined to + * + * (path: String) -> [R] -> (op: File^{withFile*} => R) -> R + * + * which is clearly unsound. * * Why is this sound? Covariant occurrences of cap must represent capabilities * that are reachable from `x`, so they are included in the meaning of `{x*}`. @@ -243,28 +259,43 @@ extension (tp: Type) * occurrences of cap are allowed in instance types of type variables. */ def withReachCaptures(ref: Type)(using Context): Type = - object narrowCaps extends TypeMap: + class CheckContraCaps extends TypeTraverser: var ok = true - def apply(t: Type) = t.dealias match - case t1 @ CapturingType(p, cs) if cs.isUniversal => - if variance > 0 then - t1.derivedCapturingType(apply(p), ref.reach.singletonCaptureSet) - else - ok = false - t - case _ => t match - case t @ CapturingType(p, cs) => - t.derivedCapturingType(apply(p), cs) // don't map capture set variables - case t => - mapOver(t) + def traverse(t: Type): Unit = + if ok then + t match + case CapturingType(_, cs) if cs.isUniversal && variance <= 0 => + ok = false + case _ => + traverseChildren(t) + + object narrowCaps extends TypeMap: + /** Has the variance been flipped at this point? */ + private var isFlipped: Boolean = false + + def apply(t: Type) = + val saved = isFlipped + try + if variance <= 0 then isFlipped = true + t.dealias match + case t1 @ CapturingType(p, cs) if cs.isUniversal && !isFlipped => + t1.derivedCapturingType(apply(p), ref.reach.singletonCaptureSet) + case _ => t match + case t @ CapturingType(p, cs) => + t.derivedCapturingType(apply(p), cs) // don't map capture set variables + case t => + mapOver(t) + finally isFlipped = saved ref match case ref: CaptureRef if ref.isTrackableRef => - val tp1 = narrowCaps(tp) - if narrowCaps.ok then + val checker = new CheckContraCaps + checker.traverse(tp) + if checker.ok then + val tp1 = narrowCaps(tp) if tp1 ne tp then capt.println(i"narrow $tp of $ref to $tp1") tp1 else - capt.println(i"cannot narrow $tp of $ref to $tp1") + capt.println(i"cannot narrow $tp of $ref") tp case _ => tp diff --git a/tests/neg-custom-args/captures/refine-reach-shallow.scala b/tests/neg-custom-args/captures/refine-reach-shallow.scala new file mode 100644 index 000000000000..9f4b28ce52e3 --- /dev/null +++ b/tests/neg-custom-args/captures/refine-reach-shallow.scala @@ -0,0 +1,18 @@ +import language.experimental.captureChecking +trait IO +def test1(): Unit = + val f: IO^ => IO^ = x => x + val g: IO^ => IO^{f*} = f // error +def test2(): Unit = + val f: [R] -> (IO^ => R) -> R = ??? + val g: [R] -> (IO^{f*} => R) -> R = f // error +def test3(): Unit = + val f: [R] -> (IO^ -> R) -> R = ??? + val g: [R] -> (IO^{f*} -> R) -> R = f // error +def test4(): Unit = + val xs: List[IO^] = ??? + val ys: List[IO^{xs*}] = xs // ok +def test5(): Unit = + val f: [R] -> (IO^ -> R) -> IO^ = ??? + val g: [R] -> (IO^ -> R) -> IO^{f*} = f // ok + val h: [R] -> (IO^{f*} -> R) -> IO^ = f // error diff --git a/tests/neg-custom-args/captures/refine-withFile.scala b/tests/neg-custom-args/captures/refine-withFile.scala new file mode 100644 index 000000000000..823b62711d05 --- /dev/null +++ b/tests/neg-custom-args/captures/refine-withFile.scala @@ -0,0 +1,8 @@ +import language.experimental.captureChecking + +trait File +val useFile: [R] -> (path: String) -> (op: File^ -> R) -> R = ??? +def main(): Unit = + val f: [R] -> (path: String) -> (op: File^ -> R) -> R = useFile + val g: [R] -> (path: String) -> (op: File^{f*} -> R) -> R = f // error + val leaked = g[File^{f*}]("test")(f => f) // boom