From 590016e6fc392b13e57c5c485d3d1a5c070ca71b Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 6 Jul 2021 13:29:32 +0200 Subject: [PATCH 1/3] Make CapturingTypes take sets instead of single references --- .../dotty/tools/dotc/core/CaptureSet.scala | 17 ++++- .../dotty/tools/dotc/core/Definitions.scala | 2 +- .../tools/dotc/core/OrderingConstraint.scala | 2 +- .../tools/dotc/core/SymDenotations.scala | 2 +- .../dotty/tools/dotc/core/TypeComparer.scala | 11 ++- .../src/dotty/tools/dotc/core/TypeOps.scala | 6 +- .../src/dotty/tools/dotc/core/Types.scala | 75 +++++++++---------- .../tools/dotc/core/tasty/TreePickler.scala | 2 +- .../tools/dotc/core/tasty/TreeUnpickler.scala | 4 +- .../tools/dotc/printing/PlainPrinter.scala | 9 ++- .../src/dotty/tools/dotc/sbt/ExtractAPI.scala | 8 +- .../tools/dotc/typer/CheckCaptures.scala | 13 ++-- .../tools/dotc/typer/ExpandCaptures.scala | 11 ++- .../dotty/tools/dotc/typer/Inferencing.scala | 2 +- .../dotty/tools/dotc/typer/TypeAssigner.scala | 24 +++--- tests/neg-custom-args/captures/boxmap.check | 2 +- 16 files changed, 104 insertions(+), 86 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/CaptureSet.scala b/compiler/src/dotty/tools/dotc/core/CaptureSet.scala index b280ddfe083c..83ab79905c5a 100644 --- a/compiler/src/dotty/tools/dotc/core/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/core/CaptureSet.scala @@ -40,6 +40,9 @@ case class CaptureSet(elems: CaptureSet.Refs) extends Showable: def <:< (that: CaptureSet)(using Context): Boolean = elems.isEmpty || elems.forall(that.accountsFor) + def map(f: CaptureRef => CaptureRef)(using Context): CaptureSet = + (empty /: elems)((cs, ref) => cs + f(ref)) + def flatMap(f: CaptureRef => CaptureSet)(using Context): CaptureSet = (empty /: elems)((cs, ref) => cs ++ f(ref)) @@ -53,6 +56,14 @@ case class CaptureSet(elems: CaptureSet.Refs) extends Showable: ((NoType: Type) /: elems) ((tp, ref) => if tp.exists then OrType(tp, ref, soft = false) else ref) + override def hashCode: Int = (0 /: elems) ((x, ref) => x + ref.hashCode) + + override def equals(other: Any) = other match + case that: CaptureSet => + this.elems.size == that.elems.size && this.elems.forall(that.elems.contains) + case _ => + false + override def toString = elems.toString override def toText(printer: Printer): Text = @@ -66,6 +77,8 @@ object CaptureSet: /** Used as a recursion brake */ @sharable private[core] val Pending = CaptureSet(SimpleIdentitySet.empty) + def universal(using Context) = defn.captureRootType.typeRef.singletonCaptureSet + def apply(elems: CaptureRef*)(using Context): CaptureSet = if elems.isEmpty then empty else CaptureSet(SimpleIdentitySet(elems.map(_.normalizedRef)*)) @@ -93,8 +106,8 @@ object CaptureSet: def recur(tp: Type): CaptureSet = tp match case tp: CaptureRef => tp.captureSet - case CapturingType(parent, ref) => - recur(parent) + ref + case CapturingType(parent, refs) => + recur(parent) ++ refs case AppliedType(tycon, args) => val cs = recur(tycon) tycon.typeParams match diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 9bd1bedc7e04..9f0f2b944db8 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -264,7 +264,7 @@ class Definitions { */ @tu lazy val AnyClass: ClassSymbol = completeClass(enterCompleteClassSymbol(ScalaPackageClass, tpnme.Any, Abstract, Nil), ensureCtor = false) def AnyType: TypeRef = AnyClass.typeRef - @tu lazy val TopType: Type = CapturingType(AnyType, captureRootType.typeRef) + @tu lazy val TopType: Type = CapturingType(AnyType, captureRootType.typeRef.singletonCaptureSet) @tu lazy val MatchableClass: ClassSymbol = completeClass(enterCompleteClassSymbol(ScalaPackageClass, tpnme.Matchable, Trait, AnyType :: Nil), ensureCtor = false) def MatchableType: TypeRef = MatchableClass.typeRef @tu lazy val AnyValClass: ClassSymbol = diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index 4302a9f54fb8..04a135a16bcd 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -331,7 +331,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds, if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp case tp: CapturingType => val parent1 = recur(tp.parent, fromBelow) - if parent1 ne tp.parent then tp.derivedCapturingType(parent1, tp.ref) else tp + if parent1 ne tp.parent then tp.derivedCapturingType(parent1, tp.refs) else tp case _ => val tp1 = tp.dealiasKeepAnnots if tp1 ne tp then diff --git a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala index 7c1aee33ca43..326724df845e 100644 --- a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala @@ -2151,7 +2151,7 @@ object SymDenotations { recur(TypeComparer.bounds(tp).hi) case tp: CapturingType => - tp.derivedCapturingType(recur(tp.parent), tp.ref) + tp.derivedCapturingType(recur(tp.parent), tp.refs) case tp: TypeProxy => def computeTypeProxy = { diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index a03b1797383f..c1273f4261e3 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -490,7 +490,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // under -Ycheck. Test case is i7965.scala. case tp1: CapturingType => - if tp2.captureSet.accountsFor(tp1.ref) then recur(tp1.parent, tp2) + if tp1.refs <:< tp2.captureSet then recur(tp1.parent, tp2) else thirdTry case tp1: MatchType => val reduced = tp1.reduced @@ -818,7 +818,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // --------------------------- // E |- x: {x} T // - CapturingType(tp2, defn.captureRootType.typeRef) + CapturingType(tp2, CaptureSet.universal) case _ => tp2 isSubType(tp1.underlying.widenExpr, tp2n, approx.addLow) } @@ -2361,8 +2361,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling tp1.underlying & tp2 case tp1: AnnotatedType if !tp1.isRefining => tp1.underlying & tp2 - case tp1: CapturingType if !tp2.captureSet.accountsFor(tp1.ref) => - tp1.parent & tp2 + case tp1: CapturingType => + val parent1 = tp1.parent & tp2 + if tp2.captureSet <:< tp1.refs then parent1 + else if parent1.exists then tp1.derivedCapturingType(parent1, tp1.refs) + else NoType case _ => NoType } diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index dfdc16a8a6d2..47bdf3b11ad4 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -169,7 +169,7 @@ object TypeOps: val normed = tp.tryNormalize if (normed.exists) normed else mapOver case tp: CapturingType - if !ctx.mode.is(Mode.Type) && tp.parent.captureSet.accountsFor(tp.ref) => + if !ctx.mode.is(Mode.Type) && tp.refs <:< tp.parent.captureSet => simplify(tp.parent, theMap) case tp: MethodicType => tp // See documentation of `Types#simplified` @@ -271,7 +271,7 @@ object TypeOps: case tp1: RecType => return tp1.rebind(approximateOr(tp1.parent, tp2)) case tp1: CapturingType => - return tp1.derivedCapturingType(approximateOr(tp1.parent, tp2), tp1.ref) + return tp1.derivedCapturingType(approximateOr(tp1.parent, tp2), tp1.refs) case err: ErrorType => return err case _ => @@ -280,7 +280,7 @@ object TypeOps: case tp2: RecType => return tp2.rebind(approximateOr(tp1, tp2.parent)) case tp2: CapturingType => - return tp2.derivedCapturingType(approximateOr(tp1, tp2.parent), tp2.ref) + return tp2.derivedCapturingType(approximateOr(tp1, tp2.parent), tp2.refs) case err: ErrorType => return err case _ => diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index f4cb1ca07041..f0d853180fea 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -373,7 +373,7 @@ object Types { case tp: AndOrType => tp.tp1.unusableForInference || tp.tp2.unusableForInference case tp: LambdaType => tp.resultType.unusableForInference || tp.paramInfos.exists(_.unusableForInference) case WildcardType(optBounds) => optBounds.unusableForInference - case CapturingType(parent, ref) => parent.unusableForInference || ref.unusableForInference + case CapturingType(parent, ref) => parent.unusableForInference case _: ErrorType => true case _ => false @@ -1385,7 +1385,7 @@ object Types { val tp1 = tp.parent.dealias1(keep) if keep(tp) then tp.derivedAnnotatedType(tp1, tp.annot) else tp1 case tp: CapturingType => - tp.derivedCapturingType(tp.parent.dealias1(keep), tp.ref) + tp.derivedCapturingType(tp.parent.dealias1(keep), tp.refs) case tp: LazyRef => tp.ref.dealias1(keep) case _ => this @@ -1838,10 +1838,12 @@ object Types { } def capturing(ref: CaptureRef)(using Context): Type = - if captureSet.accountsFor(ref) then this else CapturingType(this, ref) + if captureSet.accountsFor(ref) then this + else CapturingType(this, ref.singletonCaptureSet) def capturing(cs: CaptureSet)(using Context): Type = - (this /: cs.elems)(_.capturing(_)) + val cs1 = cs -- captureSet + if cs1.isEmpty then this else CapturingType(this, cs) /** The set of distinct symbols referred to by this type, after all aliases are expanded */ def coveringSet(using Context): Set[Symbol] = @@ -3620,9 +3622,10 @@ object Types { case tp: TermParamRef if tp.binder eq thisLambdaType => TrueDeps case tp: CapturingType => val status1 = compute(status, tp.parent, theAcc) - tp.ref.stripTypeVar match - case tp: TermParamRef if tp.binder eq thisLambdaType => combine(status1, CaptureDeps) - case _ => status1 + (status1 /: tp.refs.elems)((s, ref) => ref.stripTypeVar match + case tp: TermParamRef if tp.binder eq thisLambdaType => combine(s, CaptureDeps) + case _ => s + ) case _: ThisType | _: BoundType | NoPrefix => status case _ => (if theAcc != null then theAcc else DepAcc()).foldOver(status, tp) @@ -5162,39 +5165,38 @@ object Types { unique(CachedAnnotatedType(parent, annot)) end AnnotatedType - abstract case class CapturingType(parent: Type, ref: CaptureRef) extends AnnotOrCaptType: + abstract case class CapturingType(parent: Type, refs: CaptureSet) extends AnnotOrCaptType: override def underlying(using Context): Type = parent - def derivedCapturingType(parent: Type, ref: CaptureRef)(using Context): CapturingType = - if (parent eq this.parent) && (ref eq this.ref) then this - else CapturingType(parent, ref) - - def derivedCapturing(parent: Type, capt: Type)(using Context): Type = - if (parent eq this.parent) && (capt eq this.ref) then this - else parent.capturing(capt.captureSet) + def derivedCapturingType(parent: Type, refs: CaptureSet)(using Context): CapturingType = + if (parent eq this.parent) && (refs eq this.refs) then this + else CapturingType(parent, refs) // equals comes from case class; no matching override is needed override def computeHash(bs: Binders): Int = - doHash(bs, parent, ref) + doHash(bs, refs, parent) override def hashIsStable: Boolean = - parent.hashIsStable && ref.hashIsStable + parent.hashIsStable && refs.elems.forall(_.hashIsStable) override def eql(that: Type): Boolean = that match - case that: CapturingType => (parent eq that.parent) && (ref eq that.ref) + case that: CapturingType => (parent eq that.parent) && refs.equals(that.refs) case _ => false override def iso(that: Any, bs: BinderPairs): Boolean = that match - case that: CapturingType => parent.equals(that.parent, bs) && ref.equals(that.ref, bs) + case that: CapturingType => parent.equals(that.parent, bs) && refs.equals(that.refs) case _ => false - class CachedCapturingType(parent: Type, ref: CaptureRef) extends CapturingType(parent, ref) + class CachedCapturingType(parent: Type, refs: CaptureSet) extends CapturingType(parent, refs) object CapturingType: - def apply(parent: Type, ref: CaptureRef)(using Context): CapturingType = - unique(CachedCapturingType(parent, ref.normalizedRef)) - def checked(parent: Type, ref: Type)(using Context): CapturingType = ref match - case ref: CaptureRef => apply(parent, ref) + def apply(parent: Type, refs: CaptureSet)(using Context): CapturingType = + unique(CachedCapturingType(parent, refs.map(_.normalizedRef))) + def checked(parent: Type, tps: Type*)(using Context): CapturingType = + val refs: Seq[CaptureRef] = tps map { + case ref: CaptureRef => ref + } + apply(parent, CaptureSet(refs*)) end CapturingType // Special type objects and classes ----------------------------------------------------- @@ -5458,8 +5460,8 @@ object Types { tp.derivedMatchType(bound, scrutinee, cases) protected def derivedAnnotatedType(tp: AnnotatedType, underlying: Type, annot: Annotation): Type = tp.derivedAnnotatedType(underlying, annot) - protected def derivedCapturing(tp: CapturingType, parent: Type, capt: Type): Type = - tp.derivedCapturing(parent, capt) + protected def derivedCapturingType(tp: CapturingType, parent: Type, cs: CaptureSet): Type = + tp.derivedCapturingType(parent, cs) protected def derivedWildcardType(tp: WildcardType, bounds: Type): Type = tp.derivedWildcardType(bounds) protected def derivedSkolemType(tp: SkolemType, info: Type): Type = @@ -5539,8 +5541,8 @@ object Types { if (underlying1 eq underlying) tp else derivedAnnotatedType(tp, underlying1, mapOver(annot)) - case tp @ CapturingType(parent, ref) => - derivedCapturing(tp, this(parent), this(ref)) + case tp @ CapturingType(parent, refs) => + derivedCapturingType(tp, this(parent), refs.flatMap(this(_).captureSet)) case _: ThisType | _: BoundType @@ -5861,15 +5863,12 @@ object Types { if (underlying.isExactlyNothing) underlying else tp.derivedAnnotatedType(underlying, annot) } - override protected def derivedCapturing(tp: CapturingType, parent: Type, capt: Type): Type = - capt match + override protected def derivedCapturingType(tp: CapturingType, parent: Type, cs: CaptureSet): Type = + parent match case Range(lo, hi) => - range(derivedCapturing(tp, parent, hi), derivedCapturing(tp, parent, lo)) - case _ => parent match - case Range(lo, hi) => - range(derivedCapturing(tp, lo, capt), derivedCapturing(tp, hi, capt)) - case _ => - tp.derivedCapturing(parent, capt) + range(derivedCapturingType(tp, lo, cs), derivedCapturingType(tp, hi, cs)) + case _ => + tp.derivedCapturingType(parent, cs) override protected def derivedWildcardType(tp: WildcardType, bounds: Type): WildcardType = tp.derivedWildcardType(rangeToBounds(bounds)) @@ -6010,8 +6009,8 @@ object Types { case AnnotatedType(underlying, annot) => this(applyToAnnot(x, annot), underlying) - case CapturingType(parent, ref) => - this(this(x, parent), ref) + case CapturingType(parent, refs) => + (this(x, parent) /: refs.elems)(this) case tp: ProtoType => tp.fold(x, this) diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala index f52720aa94e6..2a0c6151b525 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala @@ -292,8 +292,8 @@ class TreePickler(pickler: TastyPickler) { writeByte(APPLIEDtype) withLength { pickleType(defn.Predef_capturing.typeRef) - pickleType(tp.ref) pickleType(tp.parent) + tp.refs.elems.foreach(pickleType(_)) } case tpe: PolyType if richTypes => pickleMethodic(POLYtype, tpe, EmptyFlags) diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala index 632569298461..f90fa174bb04 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala @@ -361,8 +361,8 @@ class TreeUnpickler(reader: TastyReader, val args = until(end)(readType()) tycon match case tycon: TypeRef if tycon.symbol == defn.Predef_capturing => - if ctx.settings.Ycc.value then CapturingType.checked(args(1), args(0)) - else args(1) + if ctx.settings.Ycc.value then CapturingType.checked(args.head, args.tail*) + else args.head case _ => tycon.appliedTo(args) case TYPEBOUNDS => diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 6b705f1ac590..9cba05c09684 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -188,11 +188,11 @@ class PlainPrinter(_ctx: Context) extends Printer { keywordStr(" match ") ~ "{" ~ casesText ~ "}" ~ (" <: " ~ toText(bound) provided !bound.isAny) }.close - case CapturingType(parent, ref) => + case CapturingType(parent, refs) => if Config.printCaptureSetsAsPrefix then - changePrec(GlobalPrec)("{" ~ toTextCaptureRef(ref) ~ "} " ~ toText(parent)) + changePrec(GlobalPrec)(toTextCaptureSet(refs) ~ " " ~ toText(parent)) else - changePrec(InfixPrec)(toText(parent) ~ " retains " ~ toTextCaptureRef(ref)) + changePrec(InfixPrec)(toText(parent) ~ " retains " ~ toText(refs.toRetainsTypeArg)) case tp: PreviousErrorType if ctx.settings.XprintTypes.value => "" // do not print previously reported error message because they may try to print this error type again recuresevely case tp: ErrorType => @@ -346,6 +346,9 @@ class PlainPrinter(_ctx: Context) extends Printer { case tp: SingletonType => toTextRef(tp) case _ => toText(tp) + def toTextCaptureSet(cs: CaptureSet): Text = + "{" ~ Text(cs.elems.toList.map(toTextCaptureRef), ", ") ~ "}" + protected def isOmittablePrefix(sym: Symbol): Boolean = defn.unqualifiedOwnerTypes.exists(_.symbol == sym) || isEmptyPrefix(sym) diff --git a/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala b/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala index 74c48e4ace8d..ec8aa1a4b7d4 100644 --- a/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala +++ b/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala @@ -175,7 +175,7 @@ private class ExtractAPICollector(using Context) extends ThunkHolder { private val byNameMarker = marker("ByName") private val matchMarker = marker("Match") private val superMarker = marker("Super") - private val holdsMarker = marker("Holds") + private val retainsMarker = marker("Retains") /** Extract the API representation of a source file */ def apiSource(tree: Tree): Seq[api.ClassLike] = { @@ -521,9 +521,9 @@ private class ExtractAPICollector(using Context) extends ThunkHolder { case SuperType(thistpe, supertpe) => val s = combineApiTypes(apiType(thistpe), apiType(supertpe)) withMarker(s, superMarker) - case CapturingType(parent, ref) => - val s = combineApiTypes(apiType(parent), apiType(ref)) - withMarker(s, holdsMarker) + case CapturingType(parent, refs) => + val s = combineApiTypes((apiType(parent) :: refs.elems.toList.map(apiType))*) + withMarker(s, retainsMarker) case _ => { internalError(i"Unhandled type $tp of class ${tp.getClass}") Constants.emptyType diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index b6ca719d19d4..0fc4e52b939f 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -99,12 +99,13 @@ class CheckCaptures extends RefineTypes: def checkWellFormed(whole: Type, pos: SrcPos)(using Context): Unit = def checkRelativeVariance(mt: MethodType) = new TypeTraverser: def traverse(tp: Type): Unit = tp match - case CapturingType(parent, ref) => - ref.stripTypeVar match - case ref @ TermParamRef(`mt`, _) if variance <= 0 => - val direction = if variance < 0 then "contra" else "in" - report.error(em"captured reference $ref appears ${direction}variantly in type $whole", pos) - case _ => + case CapturingType(parent, refs) => + for ref <- refs.elems do + ref.stripTypeVar match + case ref @ TermParamRef(`mt`, _) if variance <= 0 => + val direction = if variance < 0 then "contra" else "in" + report.error(em"captured reference $ref appears ${direction}variantly in type $whole", pos) + case _ => traverse(parent) case _ => traverseChildren(tp) diff --git a/compiler/src/dotty/tools/dotc/typer/ExpandCaptures.scala b/compiler/src/dotty/tools/dotc/typer/ExpandCaptures.scala index 84be74f3d428..e220ad6207e8 100644 --- a/compiler/src/dotty/tools/dotc/typer/ExpandCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/ExpandCaptures.scala @@ -109,8 +109,7 @@ object ExpandCaptures: CaptureSet.empty def wrapImplied(tpe: Type) = - if canAdd then - (tpe /: (outerCaptures ++ nestedCaptures(tpe)).elems)(CapturingType(_, _)) + if canAdd then tpe.capturing(outerCaptures ++ nestedCaptures(tpe)) else tpe def reportOverlap(declared: CaptureSet, implied: CaptureSet): Unit = @@ -120,10 +119,10 @@ object ExpandCaptures: pos) tpe match - case tpe @ CapturingType(parent, ref) => - reportOverlap(tpe.captureSet, outerCaptures ++ nestedCaptures(parent)) - val parent1 = addImplied(parent, bound, outerCaptures + ref, canAdd = false, pos) - tpe.derivedCapturingType(parent1, ref) + case tpe @ CapturingType(parent, refs) => + reportOverlap(refs, outerCaptures ++ nestedCaptures(parent)) + val parent1 = addImplied(parent, bound, outerCaptures ++ refs, canAdd = false, pos) + tpe.derivedCapturingType(parent1, refs) case FunctionType(tparams, params, body) => val newParamCaptures = paramCaptures(params) -- bound -- outerCaptures if newParamCaptures.nonEmpty && !tpe.isInstanceOf[RefinedType] then diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index 17db403bcfdf..ebfecdf0d211 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -525,7 +525,7 @@ object Inferencing { case tp: RecType => tp.derivedRecType(captureWildcards(tp.parent)) case tp: LazyRef => captureWildcards(tp.ref) case tp: AnnotatedType => tp.derivedAnnotatedType(captureWildcards(tp.parent), tp.annot) - case tp: CapturingType => tp.derivedCapturingType(captureWildcards(tp.parent), tp.ref) + case tp: CapturingType => tp.derivedCapturingType(captureWildcards(tp.parent), tp.refs) case _ => tp } } diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala index 0d2a909f0ebb..981d4b88e3fe 100644 --- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -186,33 +186,33 @@ trait TypeAssigner { else errorType(ex"$whatCanNot be accessed as a member of $pre$where.$whyNot", pos) def processAppliedType(tree: untpd.Tree, tp: Type)(using Context): Type = - def captType(tp: Type, refs: Type): Type = refs match - case ref: NamedType => + def include(cs: CaptureSet, tp: Type): CaptureSet = tp match + case ref: CaptureRef => if ref.isTracked then - if tp.captureSet.accountsFor(ref) then - report.warning(em"redundant capture: $tp with capture set ${tp.captureSet} already contains $ref with capture set ${ref.captureSet}", tree.srcPos) - CapturingType(tp, ref) + if cs.accountsFor(ref) then + report.warning(em"redundant capture: $cs already accounts for $ref", tree.srcPos) + cs + ref else val reason = if ref.canBeTracked then "its capture set is empty" else "it is not a parameter or a local variable" report.error(em"$ref cannot be tracked since $reason", tree.srcPos) - tp - case OrType(refs1, refs2) => - captType(captType(tp, refs1), refs2) + cs + case OrType(tp1, tp2) => + include(include(cs, tp1), tp2) case _ => - report.error(em"$refs is not a legal type for a capture set", tree.srcPos) - tp + report.error(em"$tp is not a legal type for a capture set", tree.srcPos) + cs tp match case AppliedType(tycon, args) => val constr = tycon.typeSymbol if constr == defn.andType then AndType(args(0), args(1)) else if constr == defn.orType then OrType(args(0), args(1), soft = false) else if constr == defn.Predef_retainsType then - if ctx.settings.Ycc.value then captType(args(0), args(1)) + if ctx.settings.Ycc.value then CapturingType(args(0), include(CaptureSet.empty, args(1))) else args(0) else if constr == defn.Predef_capturing then - if ctx.settings.Ycc.value then captType(args(1), args(0)) + if ctx.settings.Ycc.value then CapturingType(args(1), include(CaptureSet.empty, args(0))) else args(1) else tp case _ => tp diff --git a/tests/neg-custom-args/captures/boxmap.check b/tests/neg-custom-args/captures/boxmap.check index da1c38b4a887..b7f8ef24cbd6 100644 --- a/tests/neg-custom-args/captures/boxmap.check +++ b/tests/neg-custom-args/captures/boxmap.check @@ -1,7 +1,7 @@ -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/boxmap.scala:15:2 ---------------------------------------- 15 | () => b[Box[B]]((x: A) => box(f(x))) // error | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | Found: {f} {b} () => Box[B] + | Found: {b, f} () => Box[B] | Required: {B} () => Box[B] | | where: B is a type in method lazymap with bounds <: Top From 986ad6691dce2f50ca79563014f21541638df20c Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 6 Jul 2021 14:28:26 +0200 Subject: [PATCH 2/3] Allow given capture sets to be empty Needed to override capture inference to state that a type does not capture anything. --- .../dotty/tools/dotc/parsing/Parsers.scala | 37 +++++++++---------- .../dotty/tools/dotc/typer/TypeAssigner.scala | 3 +- .../pos-custom-args/captures/cc-expand.scala | 7 +++- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 9e0351c4d50b..537bc8ff2793 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -900,21 +900,18 @@ object Parsers { def followingIsCaptureSet(): Boolean = val lookahead = in.LookaheadScanner() - def recur(): Boolean = - lookahead.isIdent && { + def skipElems(): Unit = + lookahead.nextToken() + if lookahead.isIdent then lookahead.nextToken() - if lookahead.token == COMMA then - lookahead.nextToken() - recur() - else - lookahead.token == RBRACE && { - lookahead.nextToken() - canStartInfixTypeTokens.contains(lookahead.token) - || lookahead.token == LBRACKET - } - } - lookahead.nextToken() - recur() + if lookahead.token == COMMA then skipElems() + skipElems() + lookahead.token == RBRACE + && { + lookahead.nextToken() + canStartInfixTypeTokens.contains(lookahead.token) + || lookahead.token == LBRACKET + } /* --------- OPERAND/OPERATOR STACK --------------------------------------- */ @@ -1366,7 +1363,7 @@ object Parsers { * FunTypeArgs ::= InfixType * | `(' [ [ ‘[using]’ ‘['erased'] FunArgType {`,' FunArgType } ] `)' * | '(' [ ‘[using]’ ‘['erased'] TypedFunParam {',' TypedFunParam } ')' - * CaptureSet ::= `{` CaptureRef {`,` CaptureRef} `}` + * CaptureSet ::= `{` [CaptureRef {`,` CaptureRef}] `}` * CaptureRef ::= Ident */ def typ(): Tree = { @@ -1467,10 +1464,12 @@ object Parsers { else { accept(TLARROW); typ() } } else if in.token == LBRACE && followingIsCaptureSet() then - val refs = inBraces { commaSeparated(captureRef) } - val t = typ() - val captured = refs.reduce(InfixOp(_, Ident(tpnme.raw.BAR), _)) - AppliedTypeTree(TypeTree(defn.Predef_capturing.typeRef), captured :: t :: Nil) + in.nextToken() + val captured = + if in.token == RBRACE then TypeTree(defn.NothingType) + else commaSeparated(captureRef).reduce(InfixOp(_, Ident(tpnme.raw.BAR), _)) + accept(RBRACE) + AppliedTypeTree(TypeTree(defn.Predef_capturing.typeRef), captured :: typ() :: Nil) else if (in.token == INDENT) enclosed(INDENT, typ()) else infixType() diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala index 981d4b88e3fe..f3b86c9a5576 100644 --- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -188,7 +188,8 @@ trait TypeAssigner { def processAppliedType(tree: untpd.Tree, tp: Type)(using Context): Type = def include(cs: CaptureSet, tp: Type): CaptureSet = tp match case ref: CaptureRef => - if ref.isTracked then + if ref.isExactlyNothing then cs + else if ref.isTracked then if cs.accountsFor(ref) then report.warning(em"redundant capture: $cs already accounts for $ref", tree.srcPos) cs + ref diff --git a/tests/pos-custom-args/captures/cc-expand.scala b/tests/pos-custom-args/captures/cc-expand.scala index bb94f8fd3ba1..7b0b34af6ce3 100644 --- a/tests/pos-custom-args/captures/cc-expand.scala +++ b/tests/pos-custom-args/captures/cc-expand.scala @@ -6,7 +6,7 @@ object Test: class CTC type CT = CTC retains * - def test(ct: CT, dt: CT) = + def test[X <: {*} Any, Y <: {*} Any](ct: CT, dt: CT) = def x0: A => {ct} B = ??? @@ -18,4 +18,7 @@ object Test: def x5: A => (x: B retains ct.type) => () => C retains dt.type = ??? def x6: A => (x: B retains ct.type) => (() => C retains dt.type) retains x.type | dt.type = ??? - def x7: A => (x: B retains ct.type) => (() => C retains dt.type) retains x.type = ??? \ No newline at end of file + def x7: A => (x: B retains ct.type) => (() => C retains dt.type) retains x.type = ??? + + def x8: X => Y = ??? + def x9: {} X => Y = ??? From b4b5bfdba3218b1b61e603f17a6ae09ef72d6205 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 6 Jul 2021 18:53:53 +0200 Subject: [PATCH 3/3] Convert test to new syntax --- tests/pos-custom-args/captures/cc-expand.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/pos-custom-args/captures/cc-expand.scala b/tests/pos-custom-args/captures/cc-expand.scala index 7b0b34af6ce3..4f5d7a82118c 100644 --- a/tests/pos-custom-args/captures/cc-expand.scala +++ b/tests/pos-custom-args/captures/cc-expand.scala @@ -11,14 +11,14 @@ object Test: def x0: A => {ct} B = ??? def x1: A => B retains ct.type = ??? - def x2: A => B => C retains ct.type = ??? - def x3: A => () => B => C retains ct.type = ??? + def x2: A => B => {ct} C = ??? + def x3: A => () => B => {ct} C = ??? - def x4: (x: A retains ct.type) => B => C = ??? + def x4: (x: {ct} A) => B => C = ??? - def x5: A => (x: B retains ct.type) => () => C retains dt.type = ??? - def x6: A => (x: B retains ct.type) => (() => C retains dt.type) retains x.type | dt.type = ??? - def x7: A => (x: B retains ct.type) => (() => C retains dt.type) retains x.type = ??? + def x5: A => (x: {ct} B) => () => {dt} C = ??? + def x6: A => (x: {ct} B) => {x, dt} (() => {dt} C) = ??? + def x7: A => (x: {ct} B) => {x} () => {dt} C = ??? def x8: X => Y = ??? def x9: {} X => Y = ???