Skip to content

Commit

Permalink
Carry and check universal capability from parents correctly (#20004)
Browse files Browse the repository at this point in the history
Fix #18857
This PR checks universal capability from parent classes properly.
  • Loading branch information
odersky committed Apr 27, 2024
2 parents 8825b07 + f6529c4 commit 837ed3a
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 14 deletions.
4 changes: 0 additions & 4 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,6 @@ extension (tp: Type)
case _ =>
false

def isCapabilityClassRef(using Context) = tp.dealiasKeepAnnots match
case _: TypeRef | _: AppliedType => tp.typeSymbol.hasAnnotation(defn.CapabilityAnnot)
case _ => false

/** Drop @retains annotations everywhere */
def dropAllRetains(using Context): Type = // TODO we should drop retains from inferred types before unpickling
val tm = new TypeMap:
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,8 @@ class CheckCaptures extends Recheck, SymTransformer:
*/
def addParamArgRefinements(core: Type, initCs: CaptureSet): (Type, CaptureSet) =
var refined: Type = core
var allCaptures: CaptureSet = initCs
var allCaptures: CaptureSet = if setup.isCapabilityClassRef(core)
then CaptureSet.universal else initCs
for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do
val getter = cls.info.member(getterName).suchThat(_.is(ParamAccessor)).symbol
if getter.termRef.isTracked && !getter.is(Private) then
Expand Down
35 changes: 28 additions & 7 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ trait SetupAPI:
def setupUnit(tree: Tree, recheckDef: DefRecheck)(using Context): Unit
def isPreCC(sym: Symbol)(using Context): Boolean
def postCheck()(using Context): Unit
def isCapabilityClassRef(tp: Type)(using Context): Boolean

object Setup:

Expand Down Expand Up @@ -67,6 +68,31 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
&& !sym.owner.is(CaptureChecked)
&& !defn.isFunctionSymbol(sym.owner)

private val capabilityClassMap = new util.HashMap[Symbol, Boolean]

/** Check if the class is capability, which means:
* 1. the class has a capability annotation,
* 2. or at least one of its parent type has universal capability.
*/
def isCapabilityClassRef(tp: Type)(using Context): Boolean = tp.dealiasKeepAnnots match
case _: TypeRef | _: AppliedType =>
val sym = tp.classSymbol
def checkSym: Boolean =
sym.hasAnnotation(defn.CapabilityAnnot)
|| sym.info.parents.exists(hasUniversalCapability)
sym.isClass && capabilityClassMap.getOrElseUpdate(sym, checkSym)
case _ => false

private def hasUniversalCapability(tp: Type)(using Context): Boolean = tp.dealiasKeepAnnots match
case CapturingType(parent, refs) =>
refs.isUniversal || hasUniversalCapability(parent)
case AnnotatedType(parent, ann) =>
if ann.symbol.isRetains then
try ann.tree.toCaptureSet.isUniversal || hasUniversalCapability(parent)
catch case ex: IllegalCaptureRef => false
else hasUniversalCapability(parent)
case tp => isCapabilityClassRef(tp)

private def fluidify(using Context) = new TypeMap with IdempotentCaptRefMap:
def apply(t: Type): Type = t match
case t: MethodType =>
Expand Down Expand Up @@ -269,12 +295,6 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
CapturingType(fntpe, cs, boxed = false)
else fntpe

/** Map references to capability classes C to C^ */
private def expandCapabilityClass(tp: Type): Type =
if tp.isCapabilityClassRef
then CapturingType(tp, defn.expandedUniversalSet, boxed = false)
else tp

private def recur(t: Type): Type = normalizeCaptures(mapOver(t))

def apply(t: Type) =
Expand All @@ -297,7 +317,8 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
case t: TypeVar =>
this(t.underlying)
case t =>
if t.isCapabilityClassRef
// Map references to capability classes C to C^
if isCapabilityClassRef(t)
then CapturingType(t, defn.expandedUniversalSet, boxed = false)
else recur(t)
end expandAliases
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.util.control.NonFatal
import config.Config
import reporting.*
import collection.mutable
import cc.{CapturingType, derivedCapturingType}
import cc.{CapturingType, derivedCapturingType, stripCapturing}

import scala.annotation.internal.sharable
import scala.compiletime.uninitialized
Expand Down Expand Up @@ -2225,7 +2225,7 @@ object SymDenotations {
tp match {
case tp @ TypeRef(prefix, _) =>
def foldGlb(bt: Type, ps: List[Type]): Type = ps match {
case p :: ps1 => foldGlb(bt & recur(p), ps1)
case p :: ps1 => foldGlb(bt & recur(p.stripCapturing), ps1)
case _ => bt
}

Expand Down
15 changes: 15 additions & 0 deletions tests/neg-custom-args/captures/extending-cap-classes.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import annotation.capability

class C1
@capability class C2 extends C1
class C3 extends C2

def test =
val x1: C1 = new C1
val x2: C1 = new C2 // error
val x3: C1 = new C3 // error

val y1: C2 = new C2
val y2: C2 = new C3

val z1: C3 = new C3
30 changes: 30 additions & 0 deletions tests/neg-custom-args/captures/extending-impure-function.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
class F1 extends (Int => Unit) {
def apply(x: Int): Unit = ()
}

class F2 extends (Int -> Unit) {
def apply(x: Int): Unit = ()
}

def test =
val x1 = new (Int => Unit) {
def apply(x: Int): Unit = ()
}

val x2: Int -> Unit = new (Int => Unit) { // error
def apply(x: Int): Unit = ()
}

val x3: Int -> Unit = new (Int -> Unit) {
def apply(x: Int): Unit = ()
}

val y1: Int => Unit = new F1
val y2: Int -> Unit = new F1 // error
val y3: Int => Unit = new F2
val y4: Int -> Unit = new F2

val z1 = () => ()
val z2: () -> Unit = () => ()
val z3: () -> Unit = z1
val z4: () => Unit = () => ()

0 comments on commit 837ed3a

Please sign in to comment.