Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions compiler/src/dotty/tools/dotc/core/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ case class CaptureSet(elems: CaptureSet.Refs) extends Showable:
def <:< (that: CaptureSet)(using Context): Boolean =
elems.isEmpty || elems.forall(that.accountsFor)

def map(f: CaptureRef => CaptureRef)(using Context): CaptureSet =
(empty /: elems)((cs, ref) => cs + f(ref))

def flatMap(f: CaptureRef => CaptureSet)(using Context): CaptureSet =
(empty /: elems)((cs, ref) => cs ++ f(ref))

Expand All @@ -53,6 +56,14 @@ case class CaptureSet(elems: CaptureSet.Refs) extends Showable:
((NoType: Type) /: elems) ((tp, ref) =>
if tp.exists then OrType(tp, ref, soft = false) else ref)

override def hashCode: Int = (0 /: elems) ((x, ref) => x + ref.hashCode)

override def equals(other: Any) = other match
case that: CaptureSet =>
this.elems.size == that.elems.size && this.elems.forall(that.elems.contains)
case _ =>
false

override def toString = elems.toString

override def toText(printer: Printer): Text =
Expand All @@ -66,6 +77,8 @@ object CaptureSet:
/** Used as a recursion brake */
@sharable private[core] val Pending = CaptureSet(SimpleIdentitySet.empty)

def universal(using Context) = defn.captureRootType.typeRef.singletonCaptureSet

def apply(elems: CaptureRef*)(using Context): CaptureSet =
if elems.isEmpty then empty
else CaptureSet(SimpleIdentitySet(elems.map(_.normalizedRef)*))
Expand Down Expand Up @@ -93,8 +106,8 @@ object CaptureSet:
def recur(tp: Type): CaptureSet = tp match
case tp: CaptureRef =>
tp.captureSet
case CapturingType(parent, ref) =>
recur(parent) + ref
case CapturingType(parent, refs) =>
recur(parent) ++ refs
case AppliedType(tycon, args) =>
val cs = recur(tycon)
tycon.typeParams match
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class Definitions {
*/
@tu lazy val AnyClass: ClassSymbol = completeClass(enterCompleteClassSymbol(ScalaPackageClass, tpnme.Any, Abstract, Nil), ensureCtor = false)
def AnyType: TypeRef = AnyClass.typeRef
@tu lazy val TopType: Type = CapturingType(AnyType, captureRootType.typeRef)
@tu lazy val TopType: Type = CapturingType(AnyType, captureRootType.typeRef.singletonCaptureSet)
@tu lazy val MatchableClass: ClassSymbol = completeClass(enterCompleteClassSymbol(ScalaPackageClass, tpnme.Matchable, Trait, AnyType :: Nil), ensureCtor = false)
def MatchableType: TypeRef = MatchableClass.typeRef
@tu lazy val AnyValClass: ClassSymbol =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp
case tp: CapturingType =>
val parent1 = recur(tp.parent, fromBelow)
if parent1 ne tp.parent then tp.derivedCapturingType(parent1, tp.ref) else tp
if parent1 ne tp.parent then tp.derivedCapturingType(parent1, tp.refs) else tp
case _ =>
val tp1 = tp.dealiasKeepAnnots
if tp1 ne tp then
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2151,7 +2151,7 @@ object SymDenotations {
recur(TypeComparer.bounds(tp).hi)

case tp: CapturingType =>
tp.derivedCapturingType(recur(tp.parent), tp.ref)
tp.derivedCapturingType(recur(tp.parent), tp.refs)

case tp: TypeProxy =>
def computeTypeProxy = {
Expand Down
11 changes: 7 additions & 4 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
// under -Ycheck. Test case is i7965.scala.

case tp1: CapturingType =>
if tp2.captureSet.accountsFor(tp1.ref) then recur(tp1.parent, tp2)
if tp1.refs <:< tp2.captureSet then recur(tp1.parent, tp2)
else thirdTry
case tp1: MatchType =>
val reduced = tp1.reduced
Expand Down Expand Up @@ -818,7 +818,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
// ---------------------------
// E |- x: {x} T
//
CapturingType(tp2, defn.captureRootType.typeRef)
CapturingType(tp2, CaptureSet.universal)
case _ => tp2
isSubType(tp1.underlying.widenExpr, tp2n, approx.addLow)
}
Expand Down Expand Up @@ -2361,8 +2361,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
tp1.underlying & tp2
case tp1: AnnotatedType if !tp1.isRefining =>
tp1.underlying & tp2
case tp1: CapturingType if !tp2.captureSet.accountsFor(tp1.ref) =>
tp1.parent & tp2
case tp1: CapturingType =>
val parent1 = tp1.parent & tp2
if tp2.captureSet <:< tp1.refs then parent1
else if parent1.exists then tp1.derivedCapturingType(parent1, tp1.refs)
else NoType
case _ =>
NoType
}
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ object TypeOps:
val normed = tp.tryNormalize
if (normed.exists) normed else mapOver
case tp: CapturingType
if !ctx.mode.is(Mode.Type) && tp.parent.captureSet.accountsFor(tp.ref) =>
if !ctx.mode.is(Mode.Type) && tp.refs <:< tp.parent.captureSet =>
simplify(tp.parent, theMap)
case tp: MethodicType =>
tp // See documentation of `Types#simplified`
Expand Down Expand Up @@ -271,7 +271,7 @@ object TypeOps:
case tp1: RecType =>
return tp1.rebind(approximateOr(tp1.parent, tp2))
case tp1: CapturingType =>
return tp1.derivedCapturingType(approximateOr(tp1.parent, tp2), tp1.ref)
return tp1.derivedCapturingType(approximateOr(tp1.parent, tp2), tp1.refs)
case err: ErrorType =>
return err
case _ =>
Expand All @@ -280,7 +280,7 @@ object TypeOps:
case tp2: RecType =>
return tp2.rebind(approximateOr(tp1, tp2.parent))
case tp2: CapturingType =>
return tp2.derivedCapturingType(approximateOr(tp1, tp2.parent), tp2.ref)
return tp2.derivedCapturingType(approximateOr(tp1, tp2.parent), tp2.refs)
case err: ErrorType =>
return err
case _ =>
Expand Down
75 changes: 37 additions & 38 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ object Types {
case tp: AndOrType => tp.tp1.unusableForInference || tp.tp2.unusableForInference
case tp: LambdaType => tp.resultType.unusableForInference || tp.paramInfos.exists(_.unusableForInference)
case WildcardType(optBounds) => optBounds.unusableForInference
case CapturingType(parent, ref) => parent.unusableForInference || ref.unusableForInference
case CapturingType(parent, ref) => parent.unusableForInference
case _: ErrorType => true
case _ => false

Expand Down Expand Up @@ -1385,7 +1385,7 @@ object Types {
val tp1 = tp.parent.dealias1(keep)
if keep(tp) then tp.derivedAnnotatedType(tp1, tp.annot) else tp1
case tp: CapturingType =>
tp.derivedCapturingType(tp.parent.dealias1(keep), tp.ref)
tp.derivedCapturingType(tp.parent.dealias1(keep), tp.refs)
case tp: LazyRef =>
tp.ref.dealias1(keep)
case _ => this
Expand Down Expand Up @@ -1838,10 +1838,12 @@ object Types {
}

def capturing(ref: CaptureRef)(using Context): Type =
if captureSet.accountsFor(ref) then this else CapturingType(this, ref)
if captureSet.accountsFor(ref) then this
else CapturingType(this, ref.singletonCaptureSet)

def capturing(cs: CaptureSet)(using Context): Type =
(this /: cs.elems)(_.capturing(_))
val cs1 = cs -- captureSet
if cs1.isEmpty then this else CapturingType(this, cs)

/** The set of distinct symbols referred to by this type, after all aliases are expanded */
def coveringSet(using Context): Set[Symbol] =
Expand Down Expand Up @@ -3620,9 +3622,10 @@ object Types {
case tp: TermParamRef if tp.binder eq thisLambdaType => TrueDeps
case tp: CapturingType =>
val status1 = compute(status, tp.parent, theAcc)
tp.ref.stripTypeVar match
case tp: TermParamRef if tp.binder eq thisLambdaType => combine(status1, CaptureDeps)
case _ => status1
(status1 /: tp.refs.elems)((s, ref) => ref.stripTypeVar match
case tp: TermParamRef if tp.binder eq thisLambdaType => combine(s, CaptureDeps)
case _ => s
)
case _: ThisType | _: BoundType | NoPrefix => status
case _ =>
(if theAcc != null then theAcc else DepAcc()).foldOver(status, tp)
Expand Down Expand Up @@ -5162,39 +5165,38 @@ object Types {
unique(CachedAnnotatedType(parent, annot))
end AnnotatedType

abstract case class CapturingType(parent: Type, ref: CaptureRef) extends AnnotOrCaptType:
abstract case class CapturingType(parent: Type, refs: CaptureSet) extends AnnotOrCaptType:
override def underlying(using Context): Type = parent

def derivedCapturingType(parent: Type, ref: CaptureRef)(using Context): CapturingType =
if (parent eq this.parent) && (ref eq this.ref) then this
else CapturingType(parent, ref)

def derivedCapturing(parent: Type, capt: Type)(using Context): Type =
if (parent eq this.parent) && (capt eq this.ref) then this
else parent.capturing(capt.captureSet)
def derivedCapturingType(parent: Type, refs: CaptureSet)(using Context): CapturingType =
if (parent eq this.parent) && (refs eq this.refs) then this
else CapturingType(parent, refs)

// equals comes from case class; no matching override is needed

override def computeHash(bs: Binders): Int =
doHash(bs, parent, ref)
doHash(bs, refs, parent)
override def hashIsStable: Boolean =
parent.hashIsStable && ref.hashIsStable
parent.hashIsStable && refs.elems.forall(_.hashIsStable)

override def eql(that: Type): Boolean = that match
case that: CapturingType => (parent eq that.parent) && (ref eq that.ref)
case that: CapturingType => (parent eq that.parent) && refs.equals(that.refs)
case _ => false

override def iso(that: Any, bs: BinderPairs): Boolean = that match
case that: CapturingType => parent.equals(that.parent, bs) && ref.equals(that.ref, bs)
case that: CapturingType => parent.equals(that.parent, bs) && refs.equals(that.refs)
case _ => false

class CachedCapturingType(parent: Type, ref: CaptureRef) extends CapturingType(parent, ref)
class CachedCapturingType(parent: Type, refs: CaptureSet) extends CapturingType(parent, refs)

object CapturingType:
def apply(parent: Type, ref: CaptureRef)(using Context): CapturingType =
unique(CachedCapturingType(parent, ref.normalizedRef))
def checked(parent: Type, ref: Type)(using Context): CapturingType = ref match
case ref: CaptureRef => apply(parent, ref)
def apply(parent: Type, refs: CaptureSet)(using Context): CapturingType =
unique(CachedCapturingType(parent, refs.map(_.normalizedRef)))
def checked(parent: Type, tps: Type*)(using Context): CapturingType =
val refs: Seq[CaptureRef] = tps map {
case ref: CaptureRef => ref
}
apply(parent, CaptureSet(refs*))
end CapturingType

// Special type objects and classes -----------------------------------------------------
Expand Down Expand Up @@ -5458,8 +5460,8 @@ object Types {
tp.derivedMatchType(bound, scrutinee, cases)
protected def derivedAnnotatedType(tp: AnnotatedType, underlying: Type, annot: Annotation): Type =
tp.derivedAnnotatedType(underlying, annot)
protected def derivedCapturing(tp: CapturingType, parent: Type, capt: Type): Type =
tp.derivedCapturing(parent, capt)
protected def derivedCapturingType(tp: CapturingType, parent: Type, cs: CaptureSet): Type =
tp.derivedCapturingType(parent, cs)
protected def derivedWildcardType(tp: WildcardType, bounds: Type): Type =
tp.derivedWildcardType(bounds)
protected def derivedSkolemType(tp: SkolemType, info: Type): Type =
Expand Down Expand Up @@ -5539,8 +5541,8 @@ object Types {
if (underlying1 eq underlying) tp
else derivedAnnotatedType(tp, underlying1, mapOver(annot))

case tp @ CapturingType(parent, ref) =>
derivedCapturing(tp, this(parent), this(ref))
case tp @ CapturingType(parent, refs) =>
derivedCapturingType(tp, this(parent), refs.flatMap(this(_).captureSet))

case _: ThisType
| _: BoundType
Expand Down Expand Up @@ -5861,15 +5863,12 @@ object Types {
if (underlying.isExactlyNothing) underlying
else tp.derivedAnnotatedType(underlying, annot)
}
override protected def derivedCapturing(tp: CapturingType, parent: Type, capt: Type): Type =
capt match
override protected def derivedCapturingType(tp: CapturingType, parent: Type, cs: CaptureSet): Type =
parent match
case Range(lo, hi) =>
range(derivedCapturing(tp, parent, hi), derivedCapturing(tp, parent, lo))
case _ => parent match
case Range(lo, hi) =>
range(derivedCapturing(tp, lo, capt), derivedCapturing(tp, hi, capt))
case _ =>
tp.derivedCapturing(parent, capt)
range(derivedCapturingType(tp, lo, cs), derivedCapturingType(tp, hi, cs))
case _ =>
tp.derivedCapturingType(parent, cs)

override protected def derivedWildcardType(tp: WildcardType, bounds: Type): WildcardType =
tp.derivedWildcardType(rangeToBounds(bounds))
Expand Down Expand Up @@ -6010,8 +6009,8 @@ object Types {
case AnnotatedType(underlying, annot) =>
this(applyToAnnot(x, annot), underlying)

case CapturingType(parent, ref) =>
this(this(x, parent), ref)
case CapturingType(parent, refs) =>
(this(x, parent) /: refs.elems)(this)

case tp: ProtoType =>
tp.fold(x, this)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ class TreePickler(pickler: TastyPickler) {
writeByte(APPLIEDtype)
withLength {
pickleType(defn.Predef_capturing.typeRef)
pickleType(tp.ref)
pickleType(tp.parent)
tp.refs.elems.foreach(pickleType(_))
}
case tpe: PolyType if richTypes =>
pickleMethodic(POLYtype, tpe, EmptyFlags)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,8 @@ class TreeUnpickler(reader: TastyReader,
val args = until(end)(readType())
tycon match
case tycon: TypeRef if tycon.symbol == defn.Predef_capturing =>
if ctx.settings.Ycc.value then CapturingType.checked(args(1), args(0))
else args(1)
if ctx.settings.Ycc.value then CapturingType.checked(args.head, args.tail*)
else args.head
case _ =>
tycon.appliedTo(args)
case TYPEBOUNDS =>
Expand Down
37 changes: 18 additions & 19 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -900,21 +900,18 @@ object Parsers {

def followingIsCaptureSet(): Boolean =
val lookahead = in.LookaheadScanner()
def recur(): Boolean =
lookahead.isIdent && {
def skipElems(): Unit =
lookahead.nextToken()
if lookahead.isIdent then
lookahead.nextToken()
if lookahead.token == COMMA then
lookahead.nextToken()
recur()
else
lookahead.token == RBRACE && {
lookahead.nextToken()
canStartInfixTypeTokens.contains(lookahead.token)
|| lookahead.token == LBRACKET
}
}
lookahead.nextToken()
recur()
if lookahead.token == COMMA then skipElems()
skipElems()
lookahead.token == RBRACE
&& {
lookahead.nextToken()
canStartInfixTypeTokens.contains(lookahead.token)
|| lookahead.token == LBRACKET
}

/* --------- OPERAND/OPERATOR STACK --------------------------------------- */

Expand Down Expand Up @@ -1366,7 +1363,7 @@ object Parsers {
* FunTypeArgs ::= InfixType
* | `(' [ [ ‘[using]’ ‘['erased'] FunArgType {`,' FunArgType } ] `)'
* | '(' [ ‘[using]’ ‘['erased'] TypedFunParam {',' TypedFunParam } ')'
* CaptureSet ::= `{` CaptureRef {`,` CaptureRef} `}`
* CaptureSet ::= `{` [CaptureRef {`,` CaptureRef}] `}`
* CaptureRef ::= Ident
*/
def typ(): Tree = {
Expand Down Expand Up @@ -1467,10 +1464,12 @@ object Parsers {
else { accept(TLARROW); typ() }
}
else if in.token == LBRACE && followingIsCaptureSet() then
val refs = inBraces { commaSeparated(captureRef) }
val t = typ()
val captured = refs.reduce(InfixOp(_, Ident(tpnme.raw.BAR), _))
AppliedTypeTree(TypeTree(defn.Predef_capturing.typeRef), captured :: t :: Nil)
in.nextToken()
val captured =
if in.token == RBRACE then TypeTree(defn.NothingType)
else commaSeparated(captureRef).reduce(InfixOp(_, Ident(tpnme.raw.BAR), _))
accept(RBRACE)
AppliedTypeTree(TypeTree(defn.Predef_capturing.typeRef), captured :: typ() :: Nil)
else if (in.token == INDENT) enclosed(INDENT, typ())
else infixType()

Expand Down
9 changes: 6 additions & 3 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,11 @@ class PlainPrinter(_ctx: Context) extends Printer {
keywordStr(" match ") ~ "{" ~ casesText ~ "}" ~
(" <: " ~ toText(bound) provided !bound.isAny)
}.close
case CapturingType(parent, ref) =>
case CapturingType(parent, refs) =>
if Config.printCaptureSetsAsPrefix then
changePrec(GlobalPrec)("{" ~ toTextCaptureRef(ref) ~ "} " ~ toText(parent))
changePrec(GlobalPrec)(toTextCaptureSet(refs) ~ " " ~ toText(parent))
else
changePrec(InfixPrec)(toText(parent) ~ " retains " ~ toTextCaptureRef(ref))
changePrec(InfixPrec)(toText(parent) ~ " retains " ~ toText(refs.toRetainsTypeArg))
case tp: PreviousErrorType if ctx.settings.XprintTypes.value =>
"<error>" // do not print previously reported error message because they may try to print this error type again recuresevely
case tp: ErrorType =>
Expand Down Expand Up @@ -346,6 +346,9 @@ class PlainPrinter(_ctx: Context) extends Printer {
case tp: SingletonType => toTextRef(tp)
case _ => toText(tp)

def toTextCaptureSet(cs: CaptureSet): Text =
"{" ~ Text(cs.elems.toList.map(toTextCaptureRef), ", ") ~ "}"

protected def isOmittablePrefix(sym: Symbol): Boolean =
defn.unqualifiedOwnerTypes.exists(_.symbol == sym) || isEmptyPrefix(sym)

Expand Down
Loading