diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index ba7b9132e88a..66f67c5a12c0 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -8,7 +8,7 @@ import Symbols.*, StdNames.*, Trees.*, ContextOps.* import Decorators.* import Annotations.Annotation import NameKinds.{UniqueName, ContextBoundParamName, ContextFunctionParamName, DefaultGetterName, WildcardParamName} -import typer.{Namer, Checking} +import typer.{Namer, Checking, ErrorReporting} import util.{Property, SourceFile, SourcePosition, SrcPos, Chars} import config.{Feature, Config} import config.Feature.{sourceVersion, migrateTo3, enabled} @@ -213,9 +213,10 @@ object desugar { def valDef(vdef0: ValDef)(using Context): Tree = val vdef @ ValDef(_, tpt, rhs) = vdef0 val valName = normalizeName(vdef, tpt).asTermName + val tpt1 = desugarQualifiedTypes(tpt, valName) var mods1 = vdef.mods - val vdef1 = cpy.ValDef(vdef)(name = valName).withMods(mods1) + val vdef1 = cpy.ValDef(vdef)(name = valName, tpt = tpt1).withMods(mods1) if isSetterNeeded(vdef) then val setterParam = makeSyntheticParameter(tpt = SetterParamTree().watching(vdef)) @@ -232,6 +233,14 @@ object desugar { else vdef1 end valDef + def caseDef(cdef: CaseDef)(using Context): CaseDef = + if Feature.qualifiedTypesEnabled then + val CaseDef(pat, guard, body) = cdef + val pat1 = DesugarQualifiedTypesInPatternMap().transform(pat) + cpy.CaseDef(cdef)(pat1, guard, body) + else + cdef + def mapParamss(paramss: List[ParamClause]) (mapTypeParam: TypeDef => TypeDef) (mapTermParam: ValDef => ValDef)(using Context): List[ParamClause] = @@ -2347,6 +2356,8 @@ object desugar { case PatDef(mods, pats, tpt, rhs) => val pats1 = if (tpt.isEmpty) pats else pats map (Typed(_, tpt)) flatTree(pats1 map (makePatDef(tree, mods, _, rhs))) + case QualifiedTypeTree(parent, paramName, qualifier) => + qualifiedType(parent, paramName.getOrElse(nme.WILDCARD), qualifier, tree.span) case ext: ExtMethods => Block(List(ext), syntheticUnitLiteral.withSpan(ext.span)) case f: FunctionWithMods if f.hasErasedParams => makeFunctionWithValDefs(f, pt) @@ -2525,4 +2536,51 @@ object desugar { collect(tree) buf.toList } + + /** Desugar subtrees that are `QualifiedTypeTree`s using `outerParamName` as + * the qualified parameter name. + */ + private def desugarQualifiedTypes(tpt: Tree, outerParamName: TermName)(using Context): Tree = + def transform(tree: Tree): Tree = + tree match + case QualifiedTypeTree(parent, None, qualifier) => + qualifiedType(transform(parent), outerParamName, qualifier, tree.span) + case QualifiedTypeTree(parent, paramName, qualifier) => + cpy.QualifiedTypeTree(tree)(transform(parent), paramName, qualifier) + case TypeApply(fn, args) => + cpy.TypeApply(tree)(transform(fn), args) + case AppliedTypeTree(fn, args) => + cpy.AppliedTypeTree(tree)(transform(fn), args) + case InfixOp(left, op, right) => + cpy.InfixOp(tree)(transform(left), op, transform(right)) + case Parens(arg) => + cpy.Parens(tree)(transform(arg)) + case _ => + tree + + if Feature.qualifiedTypesEnabled then + trace(i"desugar qualified types in pattern: $tpt", Printers.qualifiedTypes): + transform(tpt) + else + tpt + + private class DesugarQualifiedTypesInPatternMap extends UntypedTreeMap: + override def transform(tree: Tree)(using Context): Tree = + tree match + case Typed(ident @ Ident(name: TermName), tpt) => + cpy.Typed(tree)(ident, desugarQualifiedTypes(tpt, name)) + case _ => + super.transform(tree) + + /** Returns the annotated type used to represent the qualified type with the + * given components: + * `parent @qualified[parent]((paramName: parent) => qualifier)`. + */ + def qualifiedType(parent: Tree, paramName: TermName, qualifier: Tree, span: Span)(using Context): Tree = + val param = makeParameter(paramName, parent, EmptyModifiers) // paramName: parent + val predicate = WildcardFunction(List(param), qualifier) // (paramName: parent) => qualifier + val qualifiedAnnot = scalaAnnotationDot(nme.qualified) + val annot = Apply(TypeApply(qualifiedAnnot, List(parent)), predicate).withSpan(span) // @qualified[parent](predicate) + Annotated(parent, annot).withSpan(span) // parent @qualified[parent](predicate) + } diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index eddc0c675e9a..494f07288865 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1439,6 +1439,19 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { def unapply(ts: List[Tree]): Option[List[Tree]] = if ts.nonEmpty && ts.head.isType then Some(ts) else None + + /** An extractor for trees that are constant values. */ + object ConstantTree: + def unapply(tree: Tree)(using Context): Option[Constant] = + tree match + case Inlined(_, Nil, expr) => unapply(expr) + case Typed(expr, _) => unapply(expr) + case Literal(c) if c.tag == Constants.NullTag => Some(c) + case _ => + tree.tpe.widenTermRefExpr.normalized.simplified match + case ConstantType(c) => Some(c) + case _ => None + /** Split argument clauses into a leading type argument clause if it exists and * remaining clauses */ diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 17dbb5bff213..85f2a874b877 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -156,6 +156,13 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { */ case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree + /** `{ x: parent with qualifier }` if `paramName == Some(x)`, + * `parent with qualifier` otherwise. + * + * Only relevant under `qualifiedTypes`. + */ + case class QualifiedTypeTree(parent: Tree, paramName: Option[TermName], qualifier: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree + /** A type tree appearing somewhere in the untyped DefDef of a lambda, it will be typed using `tpFun`. * * @param isResult Is this the result type of the lambda? This is handled specially in `Namer#valOrDefDefSig`. @@ -466,7 +473,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { def New(tpt: Tree, argss: List[List[Tree]])(using Context): Tree = ensureApplied(argss.foldLeft(makeNew(tpt))(Apply(_, _))) - /** A new expression with constrictor and possibly type arguments. See + /** A new expression with constructor and possibly type arguments. See * `New(tpt, argss)` for details. */ def makeNew(tpt: Tree)(using Context): Tree = { @@ -732,6 +739,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case tree: CapturesAndResult if (refs eq tree.refs) && (parent eq tree.parent) => tree case _ => finalize(tree, untpd.CapturesAndResult(refs, parent)) + def QualifiedTypeTree(tree: Tree)(parent: Tree, paramName: Option[TermName], qualifier: Tree)(using Context): Tree = tree match + case tree: QualifiedTypeTree if (parent eq tree.parent) && (paramName eq tree.paramName) && (qualifier eq tree.qualifier) => tree + case _ => finalize(tree, untpd.QualifiedTypeTree(parent, paramName, qualifier)(using tree.source)) + def TypedSplice(tree: Tree)(splice: tpd.Tree)(using Context): ProxyTree = tree match { case tree: TypedSplice if splice `eq` tree.splice => tree case _ => finalize(tree, untpd.TypedSplice(splice)(using ctx)) @@ -795,6 +806,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { cpy.MacroTree(tree)(transform(expr)) case CapturesAndResult(refs, parent) => cpy.CapturesAndResult(tree)(transform(refs), transform(parent)) + case QualifiedTypeTree(parent, paramName, qualifier) => + cpy.QualifiedTypeTree(tree)(transform(parent), paramName, transform(qualifier)) case _ => super.transformMoreCases(tree) } @@ -854,6 +867,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { this(x, expr) case CapturesAndResult(refs, parent) => this(this(x, refs), parent) + case QualifiedTypeTree(parent, paramName, qualifier) => + this(this(x, parent), qualifier) case _ => super.foldMoreCases(x, tree) } diff --git a/compiler/src/dotty/tools/dotc/config/Feature.scala b/compiler/src/dotty/tools/dotc/config/Feature.scala index 02bdb16ae217..d4539ca1db96 100644 --- a/compiler/src/dotty/tools/dotc/config/Feature.scala +++ b/compiler/src/dotty/tools/dotc/config/Feature.scala @@ -33,6 +33,7 @@ object Feature: val pureFunctions = experimental("pureFunctions") val captureChecking = experimental("captureChecking") val separationChecking = experimental("separationChecking") + val qualifiedTypes = experimental("qualifiedTypes") val into = experimental("into") val modularity = experimental("modularity") val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions") @@ -64,6 +65,7 @@ object Feature: (pureFunctions, "Enable pure functions for capture checking"), (captureChecking, "Enable experimental capture checking"), (separationChecking, "Enable experimental separation checking (requires captureChecking)"), + (qualifiedTypes, "Enable experimental qualified types"), (into, "Allow into modifier on parameter types"), (modularity, "Enable experimental modularity features"), (packageObjectValues, "Enable experimental package objects as values"), @@ -150,6 +152,10 @@ object Feature: if ctx.run != null then ctx.run.nn.ccEnabledSomewhere else enabledBySetting(captureChecking) + /** Is qualifiedTypes enabled for this compilation unit? */ + def qualifiedTypesEnabled(using Context) = + enabledBySetting(qualifiedTypes) + def sourceVersionSetting(using Context): SourceVersion = SourceVersion.valueOf(ctx.settings.source.value) diff --git a/compiler/src/dotty/tools/dotc/config/Printers.scala b/compiler/src/dotty/tools/dotc/config/Printers.scala index 4c66e1cdf833..2ace7b1f402a 100644 --- a/compiler/src/dotty/tools/dotc/config/Printers.scala +++ b/compiler/src/dotty/tools/dotc/config/Printers.scala @@ -51,6 +51,7 @@ object Printers { val overload = noPrinter val patmatch = noPrinter val pickling = noPrinter + val qualifiedTypes = noPrinter val quotePickling = noPrinter val plugins = noPrinter val recheckr = noPrinter diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index dbe1602e2d82..f8c8836b297c 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -625,9 +625,11 @@ class Definitions { @tu lazy val Int_/ : Symbol = IntClass.requiredMethod(nme.DIV, List(IntType)) @tu lazy val Int_* : Symbol = IntClass.requiredMethod(nme.MUL, List(IntType)) @tu lazy val Int_== : Symbol = IntClass.requiredMethod(nme.EQ, List(IntType)) + @tu lazy val Int_!= : Symbol = IntClass.requiredMethod(nme.NE, List(IntType)) @tu lazy val Int_>= : Symbol = IntClass.requiredMethod(nme.GE, List(IntType)) @tu lazy val Int_<= : Symbol = IntClass.requiredMethod(nme.LE, List(IntType)) @tu lazy val Int_> : Symbol = IntClass.requiredMethod(nme.GT, List(IntType)) + @tu lazy val Int_< : Symbol = IntClass.requiredMethod(nme.LT, List(IntType)) @tu lazy val LongType: TypeRef = valueTypeRef("scala.Long", java.lang.Long.TYPE, LongEnc, nme.specializedTypeNames.Long) def LongClass(using Context): ClassSymbol = LongType.symbol.asClass @tu lazy val Long_+ : Symbol = LongClass.requiredMethod(nme.PLUS, List(LongType)) @@ -670,6 +672,7 @@ class Definitions { @tu lazy val StringClass: ClassSymbol = requiredClass("java.lang.String") def StringType: Type = StringClass.typeRef @tu lazy val StringModule: Symbol = StringClass.linkedClass + @tu lazy val String_== : TermSymbol = enterMethod(StringClass, nme.EQ, methOfAnyRef(BooleanType), Final) @tu lazy val String_+ : TermSymbol = enterMethod(StringClass, nme.raw.PLUS, methOfAny(StringType), Final) @tu lazy val String_valueOf_Object: Symbol = StringModule.info.member(nme.valueOf).suchThat(_.info.firstParamTypes match { case List(pt) => pt.isAny || pt.stripNull().isAnyRef @@ -1048,6 +1051,7 @@ class Definitions { @tu lazy val DeprecatedAnnot: ClassSymbol = requiredClass("scala.deprecated") @tu lazy val DeprecatedOverridingAnnot: ClassSymbol = requiredClass("scala.deprecatedOverriding") @tu lazy val DeprecatedInheritanceAnnot: ClassSymbol = requiredClass("scala.deprecatedInheritance") + @tu lazy val QualifiedAnnot: ClassSymbol = requiredClass("scala.annotation.qualified") @tu lazy val ImplicitAmbiguousAnnot: ClassSymbol = requiredClass("scala.annotation.implicitAmbiguous") @tu lazy val ImplicitNotFoundAnnot: ClassSymbol = requiredClass("scala.annotation.implicitNotFound") @tu lazy val InferredDepFunAnnot: ClassSymbol = requiredClass("scala.caps.internal.inferredDepFun") diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 323c59a5711d..786324e439af 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -589,6 +589,7 @@ object StdNames { val productElementName: N = "productElementName" val productIterator: N = "productIterator" val productPrefix: N = "productPrefix" + val qualified : N = "qualified" val quotes : N = "quotes" val raw_ : N = "raw" val rd: N = "rd" diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 1d1c497b2196..88c5488b4a98 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -27,6 +27,7 @@ import Capabilities.Capability import NameKinds.WildcardParamName import MatchTypes.isConcrete import scala.util.boundary, boundary.break +import qualified_types.{QualifiedType, QualifiedTypes} /** Provides methods to compare types. */ @@ -886,6 +887,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling println(i"assertion failed while compare captured $tp1 <:< $tp2") throw ex compareCapturing || fourthTry + case QualifiedType(parent2, qualifier2) => + recur(tp1, parent2) && QualifiedTypes.typeImplies(tp1, qualifier2, qualifierSolver()) case tp2: AnnotatedType if tp2.isRefining => (tp1.derivesAnnotWith(tp2.annot.sameAnnotation) || tp1.isBottomType) && recur(tp1, tp2.parent) @@ -3306,6 +3309,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling protected def explainingTypeComparer(short: Boolean) = ExplainingTypeComparer(comparerContext, short) protected def matchReducer = MatchReducer(comparerContext) + protected def qualifierSolver() = qualified_types.QualifierSolver(using comparerContext) private def inSubComparer[T, Cmp <: TypeComparer](comparer: Cmp)(op: Cmp => T): T = val saved = myInstance @@ -3971,7 +3975,7 @@ class ExplainingTypeComparer(initctx: Context, short: Boolean) extends TypeCompa lastForwardGoal = null override def traceIndented[T](str: String)(op: => T): T = - val str1 = str.replace('\n', ' ') + val str1 = str if short && str1 == lastForwardGoal then op // repeated goal, skip for clarity else @@ -4040,5 +4044,9 @@ class ExplainingTypeComparer(initctx: Context, short: Boolean) extends TypeCompa super.subCaptures(refs1, refs2, vs) } + override def qualifierSolver() = + val traceIndented0 = [T] => (message: String) => traceIndented[T](message) + qualified_types.ExplainingQualifierSolver(traceIndented0)(using comparerContext) + def lastTrace(header: String): String = header + { try b.toString finally b.clear() } } diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 97a172966701..b12d856d5da5 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -41,7 +41,7 @@ import compiletime.uninitialized import cc.* import CaptureSet.IdentityCaptRefMap import Capabilities.* - +import qualified_types.{QualifiedType, QualifiedAnnotation} import scala.annotation.internal.sharable import scala.annotation.threadUnsafe @@ -56,7 +56,7 @@ object Types extends TypeUtils { * The principal subclasses and sub-objects are as follows: * * ```none - * Type -+- ProxyType --+- NamedType ----+--- TypeRef + * Type -+- TypeProxy --+- NamedType ----+--- TypeRef * | | \ * | +- SingletonType-+-+- TermRef * | | | @@ -191,9 +191,10 @@ object Types extends TypeUtils { /** Is this type a (possibly refined, applied, aliased or annotated) type reference * to the given type symbol? - * @sym The symbol to compare to. It must be a class symbol or abstract type. + * @param sym The symbol to compare to. It must be a class symbol or abstract type. * It makes no sense for it to be an alias type because isRef would always * return false in that case. + * @param skipRefined If true, skip refinements, annotated types and applied types. */ def isRef(sym: Symbol, skipRefined: Boolean = true)(using Context): Boolean = this match { case this1: TypeRef => @@ -211,7 +212,7 @@ object Types extends TypeUtils { else this1.underlying.isRef(sym, skipRefined) case this1: TypeVar => this1.instanceOpt.isRef(sym, skipRefined) - case this1: AnnotatedType => + case this1: AnnotatedType if (!this1.isRefining || skipRefined) => this1.parent.isRef(sym, skipRefined) case _ => false } @@ -1615,6 +1616,7 @@ object Types extends TypeUtils { def apply(tp: Type) = /*trace(i"deskolemize($tp) at $variance", show = true)*/ tp match { case tp: SkolemType => range(defn.NothingType, atVariance(1)(apply(tp.info))) + case QualifiedType(_, _) => tp case _ => mapOver(tp) } } @@ -2150,7 +2152,7 @@ object Types extends TypeUtils { /** Is `this` isomorphic to `that`, assuming pairs of matching binders `bs`? * It is assumed that `this.ne(that)`. */ - protected def iso(that: Any, bs: BinderPairs): Boolean = this.equals(that) + def iso(that: Any, bs: BinderPairs): Boolean = this.equals(that) /** Equality used for hash-consing; uses `eq` on all recursive invocations, * except where a BindingType is involved. The latter demand a deep isomorphism check. @@ -3547,7 +3549,7 @@ object Types extends TypeUtils { case _ => false } - override protected def iso(that: Any, bs: BinderPairs) = that match + override def iso(that: Any, bs: BinderPairs) = that match case that: AndType => tp1.equals(that.tp1, bs) && tp2.equals(that.tp2, bs) case _ => false } @@ -3701,7 +3703,7 @@ object Types extends TypeUtils { case _ => false } - override protected def iso(that: Any, bs: BinderPairs) = that match + override def iso(that: Any, bs: BinderPairs) = that match case that: OrType => tp1.equals(that.tp1, bs) && tp2.equals(that.tp2, bs) && isSoft == that.isSoft case _ => false } @@ -5033,7 +5035,7 @@ object Types extends TypeUtils { * anymore, or NoType if the variable can still be further constrained or a provisional * instance type in the constraint can be retracted. */ - private[core] def permanentInst = inst + def permanentInst = inst private[core] def setPermanentInst(tp: Type): Unit = inst = tp if tp.exists && owningState != null then @@ -6269,6 +6271,8 @@ object Types extends TypeUtils { tp.derivedAnnotatedType(underlying, annot) protected def derivedCapturingType(tp: Type, parent: Type, refs: CaptureSet): Type = tp.derivedCapturingType(parent, refs) + protected def derivedENodeParamRef(tp: qualified_types.ENodeParamRef, index: Int, underlying: Type): Type = + tp.derivedENodeParamRef(index, underlying) protected def derivedWildcardType(tp: WildcardType, bounds: Type): Type = tp.derivedWildcardType(bounds) protected def derivedSkolemType(tp: SkolemType, info: Type): Type = @@ -6481,6 +6485,9 @@ object Types extends TypeUtils { case tp: JavaArrayType => derivedJavaArrayType(tp, this(tp.elemType)) + case tp: qualified_types.ENodeParamRef => + derivedENodeParamRef(tp, tp.index, this(tp.underlying)) + case _ => tp } diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 3a68703e4734..1191ebdc38a8 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -448,6 +448,13 @@ object Parsers { finally inMatchPattern = saved } + private var inQualifiedType = false + private def fromWithinQualifiedType[T](body: => T): T = + val saved = inQualifiedType + inQualifiedType = true + try body + finally inQualifiedType = saved + private var staged = StageKind.None def withinStaged[T](kind: StageKind)(op: => T): T = { val saved = staged @@ -1646,6 +1653,7 @@ object Parsers { * | TypTypeParamClause ‘=>>’ Type * | FunParamClause ‘=>>’ Type * | MatchType + * | QualifiedType2 -- under qualifiedTypes * | InfixType * FunType ::= (MonoFunType | PolyFunType) * MonoFunType ::= FunTypeArgs (‘=>’ | ‘?=>’) Type @@ -1656,6 +1664,11 @@ object Parsers { * | `(' [ FunArgType {`,' FunArgType } ] `)' * | '(' [ TypedFunParam {',' TypedFunParam } ')' * MatchType ::= InfixType `match` <<< TypeCaseClauses >>> + * QualifiedType2 ::= InfixType `with` PostfixExpr + * IntoType ::= [‘into’] IntoTargetType + * | ‘( IntoType ‘)’ + * IntoTargetType ::= Type + * | FunTypeArgs (‘=>’ | ‘?=>’) IntoType */ def typ(inContextBound: Boolean = false): Tree = val start = in.offset @@ -1715,6 +1728,8 @@ object Parsers { functionRest(t :: Nil) case MATCH => matchType(t) + case WITH if in.featureEnabled(Feature.qualifiedTypes) => + qualifiedTypeShort(t) case FORSOME => syntaxError(ExistentialTypesNoLongerSupported()) t @@ -1849,6 +1864,7 @@ object Parsers { def funParamClauses(): List[List[ValDef]] = if in.token == LPAREN then funParamClause() :: funParamClauses() else Nil + /** InfixType ::= RefinedType {id [nl] RefinedType} * | RefinedType `^` -- under captureChecking */ @@ -1908,7 +1924,7 @@ object Parsers { def withType(): Tree = withTypeRest(annotType()) def withTypeRest(t: Tree): Tree = - if in.token == WITH then + if in.token == WITH && !in.featureEnabled(Feature.qualifiedTypes) then val withOffset = in.offset in.nextToken() if in.token == LBRACE || in.token == INDENT then @@ -2060,6 +2076,7 @@ object Parsers { * | ‘(’ ArgTypes ‘)’ * | ‘(’ NamesAndTypes ‘)’ * | Refinement + * | QualifiedType -- under qualifiedTypes * | TypeSplice -- deprecated syntax (since 3.0.0) * | SimpleType1 TypeArgs * | SimpleType1 `#' id @@ -2070,7 +2087,10 @@ object Parsers { makeTupleOrParens(inParensWithCommas(argTypes(namedOK = false, wildOK = true, tupleOK = true))) } else if in.token == LBRACE then - atSpan(in.offset) { RefinedTypeTree(EmptyTree, refinement(indentOK = false)) } + if in.featureEnabled(Feature.qualifiedTypes) && in.lookahead.token == IDENTIFIER then + qualifiedType() + else + atSpan(in.offset) { RefinedTypeTree(EmptyTree, refinement(indentOK = false)) } else if (isSplice) splice(isType = true) else @@ -2234,6 +2254,30 @@ object Parsers { else inBraces(refineStatSeq()) + /** QualifiedType ::= `{` Ident `:` Type `with` Block `}` + */ + def qualifiedType(): Tree = + val startOffset = in.offset + accept(LBRACE) + val id = ident() + accept(COLONfollow) + val tp = fromWithinQualifiedType(typ()) + accept(WITH) + val qualifier = block(simplify = true) + accept(RBRACE) + QualifiedTypeTree(tp, Some(id), qualifier).withSpan(Span(startOffset, qualifier.span.end)) + + /** `with` PostfixExpr + */ + def qualifiedTypeShort(t: Tree): Tree = + if inQualifiedType then + t + else + accept(WITH) + val qualifier = postfixExpr() + QualifiedTypeTree(t, None, qualifier).withSpan(Span(t.span.start, qualifier.span.end)) + + /** TypeBounds ::= [`>:' TypeBound ] [`<:' TypeBound ] * TypeBound ::= Type * | CaptureSet -- under captureChecking @@ -2310,7 +2354,12 @@ object Parsers { def typeDependingOn(location: Location): Tree = if location.inParens then typ() - else if location.inPattern then rejectWildcardType(refinedType()) + else if location.inPattern then + val t = rejectWildcardType(refinedType()) + if in.featureEnabled(Feature.qualifiedTypes) && in.token == WITH then + qualifiedTypeShort(t) + else + t else infixType() /* ----------- EXPRESSIONS ------------------------------------------------ */ @@ -3219,10 +3268,11 @@ object Parsers { if (isIdent(nme.raw.BAR)) { in.nextToken(); pattern1(location) :: patternAlts(location) } else Nil - /** Pattern1 ::= PatVar `:` RefinedType - * | [‘-’] integerLiteral `:` RefinedType - * | [‘-’] floatingPointLiteral `:` RefinedType - * | Pattern2 + /** Pattern1 ::= PatVar `:` QualifiedType3 + * | [‘-’] integerLiteral `:` QualifiedType3 + * | [‘-’] floatingPointLiteral `:` QualifiedType3 + * | Pattern2 + * QualifiedType3 ::= RefinedType [`with` PostfixExpr] */ def pattern1(location: Location = Location.InPattern): Tree = val p = pattern2(location) diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 3a0611e74ec1..9fa2ae160a51 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -469,6 +469,11 @@ class PlainPrinter(_ctx: Context) extends Printer { "<" ~ reprStr ~ ":" ~ toText(tp.info) ~ ">" else reprStr + case qualified_types.ENodeParamRef(index, underlying) => + if ctx.settings.XprintTypes.value then + "<" ~ "eparam" ~ index.toString ~ ":" ~ toText(underlying) ~ ">" + else + "eparam" ~ index.toString } } diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index 3714878a42ee..ec892976b7f9 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -842,6 +842,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { prefix ~~ idx.toString ~~ "|" ~~ tpeText ~~ "|" ~~ argsText ~~ "|" ~~ contentText ~~ postfix case CapturesAndResult(refs, parent) => changePrec(GlobalPrec)("^{" ~ Text(refs.map(toText), ", ") ~ "}" ~ toText(parent)) + case QualifiedTypeTree(parent, paramName, predicate) => + paramName match + case Some(name) => "{" ~ toText(name) ~ ": " ~ toText(parent) ~ " with " ~ toText(predicate) ~ "}" + case None => toText(parent) ~ " with " ~ toText(predicate) case ContextBoundTypeTree(tycon, pname, ownName) => toText(pname) ~ " : " ~ toText(tycon) ~ (" as " ~ toText(ownName) `provided` !ownName.isEmpty) case _ => diff --git a/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala b/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala new file mode 100644 index 000000000000..1efbac3058dc --- /dev/null +++ b/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala @@ -0,0 +1,477 @@ +package dotty.tools.dotc.qualified_types + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ListBuffer + +import dotty.tools.dotc.ast.tpd.{ + closureDef, + singleton, + Apply, + ConstantTree, + Ident, + Lambda, + Literal, + New, + Select, + This, + Tree, + TreeMap, + TreeOps, + TypeApply, + TypeTree +} +import dotty.tools.dotc.config.Printers +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Contexts.ctx +import dotty.tools.dotc.core.Decorators.i +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.Hashable.Binders +import dotty.tools.dotc.core.Names.Designator +import dotty.tools.dotc.core.StdNames.nme +import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol} +import dotty.tools.dotc.core.Types.{ + AppliedType, + CachedConstantType, + CachedProxyType, + ConstantType, + LambdaType, + MethodType, + NamedType, + NoPrefix, + SingletonType, + SkolemType, + TermParamRef, + TermRef, + Type, + TypeRef, + TypeVar, + ValueType +} +import dotty.tools.dotc.core.Uniques +import dotty.tools.dotc.qualified_types.ENode.Op +import dotty.tools.dotc.transform.TreeExtractors.BinaryOp +import dotty.tools.dotc.util.{EqHashMap, HashMap} +import dotty.tools.dotc.util.Spans.Span +import dotty.tools.dotc.reporting +import dotty.tools.dotc.config.Printers + +import annotation.threadUnsafe as tu +import reflect.ClassTag + +final class EGraph(_ctx: Context, checksEnabled: Boolean = true): + + /** Cache for unique E-Nodes + * + * Invariant: Each key is `eq` to its associated value. + * + * Invariant: If a node is in this map, then its children also are. + */ + private val index: HashMap[ENode, ENode] = HashMap() + + + private val idOf: EqHashMap[ENode, Int] = EqHashMap() + + /** Map from nodes to their unique, canonical representations. + * + * Invariant: After a call to [[repair]], if a node is in the index but not + * in this map, then it is its own representant and it is canonical. + * + * Invariant: After a call to [[repair]], values of this map are canonical. + */ + private val representantOf: EqHashMap[ENode, ENode] = EqHashMap() + + /** Map from child nodes to their parent nodes + * + * Invariant: After a call to [[repair]], values of this map are canonical. + */ + private val usedBy: EqHashMap[ENode, mutable.Set[ENode]] = EqHashMap() + + /** Worklist for nodes that need to be repairedConstantType(Constant(value). + * + * This queue is filled by [[merge]] and processed by [[repair]]. + * + * Invariant: After a call to [[repair]], this queue is empty. + */ + private val worklist = mutable.Queue.empty[ENode] + + val trueNode: ENode.Atom = constant(true) + val falseNode: ENode.Atom = constant(false) + val zeroIntNode: ENode.Atom = constant(0) + val minusOneIntNode: ENode.Atom = constant(-1) + val oneIntNode: ENode.Atom = constant(1) + + /** Returns the canonical node for the given constant value */ + def constant(value: Any): ENode.Atom = + val node = ENode.Atom(ConstantType(Constant(value))(using _ctx)) + idOf.getOrElseUpdate(node, idOf.size) + index.getOrElseUpdate(node, node).asInstanceOf[ENode.Atom] + + /** Adds the given node to the E-Graph, returning its canonical representant. + * + * Pre-condition: The node must be normalized, and its children must be + * canonical. + */ + private def unique(node: ENode): ENode = + if index.contains(node) then + representant(index(node)) + else + index.update(node, node) + idOf.update(node, idOf.size) + node match + case ENode.Atom(tp) => + () + case ENode.Constructor(sym) => + () + case ENode.Select(qual, member) => + addUse(qual, node) + case ENode.Apply(fn, args) => + addUse(fn, node) + for arg <- args do + addUse(arg, node) + case ENode.OpApply(op, args) => + for arg <- args do + addUse(arg, node) + case ENode.TypeApply(fn, args) => + addUse(fn, node) + case ENode.Lambda(paramTps, retTp, body) => + addUse(body, node) + node + node + + private def representant(node: ENode): ENode = + representantOf.get(node) match + case None => node + case Some(repr) => + // There must be no cycles in the `representantOf` map. + // If a node is canonical, it must have no representant. + assert(repr ne node, s"Node $node has itself as representant ($repr)") + representant(repr) + + def assertCanonical(node: ENode): Unit = + if checksEnabled then + // By the invariants, if a node is in the index (meaning it is tracked by + // this E-Graph), and has no representant, then it is itself a canonical + // node. We double-check by forcing a deep canonicalization. + assert(index.contains(node) && index(node) == node, s"Node $node is not unique in this E-Graph") + assert(!representantOf.contains(node), s"Node $node has a representant: ${representantOf(node)}") + val canonical = canonicalize(node) + assert(node eq canonical, s"Recanonicalization of $node did not return itself, but $canonical") + + private def addUse(child: ENode, parent: ENode): Unit = + usedBy.getOrElseUpdate(child, mutable.Set.empty) += parent + + override def toString(): String = + s"EGraph{\nindex = $index,\nrepresentantOf = $representantOf,\nusedBy = $usedBy,\nworklist = $worklist}\n" + + def toDot()(using Context): String = + val sb = new StringBuilder() + sb.append("digraph EGraph {\nnode [height=.1 shape=record]\n") + for node <- index.valuesIterator do + sb.append(node.toDot()) + for (node, repr) <- representantOf.iterator do + sb.append(s"${node.dotId()} -> ${repr.dotId()} [style=dotted]\n") + for (child, parents) <- usedBy.iterator do + for parent <- parents do + sb.append(s"${child.dotId()} -> ${parent.dotId()} [style=dashed]\n") + sb.append("}\n") + sb.toString() + + def debugString()(using _ctx: Context): String = + given Context = _ctx.withoutColors + index + .valuesIterator + .toList + .groupBy(representant) + .toList + .sortBy((repr, members) => repr.showNoBreak) + .map((repr, members) => repr.showNoBreak + ": " + members.filter(_ ne repr).map(_.showNoBreak).sorted.mkString("{", ", ", "}")) + .mkString("", "\n", "\n") + + + private inline def show(enode: ENode): String = + enode.showNoBreak(using _ctx) + + private inline def trace[T](inline message: String)(inline f: T): T = + reporting.trace(message, Printers.qualifiedTypes)(f)(using _ctx) + + def equiv(node1: ENode, node2: ENode): Boolean = + trace(s"equiv ${show(node1)}, ${show(node2)}"): + val repr1 = representant(node1) + val repr2 = representant(node2) + repr1 eq repr2 + + def merge(a: ENode, b: ENode): Unit = + if checksEnabled then + assert(index.contains(a) && index(a) == a, s"Node $a is not unique in this E-Graph") + assert(index.contains(b) && index(b) == b, s"Node $b is not unique in this E-Graph") + + val aRepr = representant(a) + val bRepr = representant(b) + + if aRepr eq bRepr then return + + if checksEnabled then + assert(aRepr != bRepr, s"$aRepr and $bRepr are `equals` but not `eq`") + + // Update representantOf and usedBy maps + val (newRepr, oldRepr) = order(aRepr, bRepr) + representantOf(oldRepr) = newRepr + val oldusages = usedBy.getOrElse(oldRepr, mutable.Set.empty) + usedBy.getOrElseUpdate(newRepr, mutable.Set.empty) ++= oldusages + usedBy.remove(oldRepr) + + trace(s"merge ${show(newRepr)} <-- ${show(oldRepr)}"): + // Propagate truth values over disjunctions, conjunctions and equalities + oldRepr match + case ENode.OpApply(Op.And, args) if newRepr eq trueNode => + args.foreach(merge(_, trueNode)) + case ENode.OpApply(Op.Or, args) if newRepr eq falseNode => + args.foreach(merge(_, falseNode)) + case ENode.OpApply(Op.Equal, args) if newRepr eq trueNode => + merge(args(0), args(1)) + case _ => + () + + // Enqueue all nodes that use the oldRepr for repair + trace(s"enqueue ${oldusages.map(show).mkString(", ")}"): + worklist.enqueueAll(oldusages) + () + + private def order(a: ENode, b: ENode): (ENode, ENode) = + if a.contains(b) then + (b, a) + else if b.contains(a) then + (a, b) + else + (a, b) match + case (ENode.Atom(_: ConstantType), _) => (a, b) + case (_, ENode.Atom(_: ConstantType)) => (b, a) + case (_: ENode.OpApply, _) => (a, b) + case (_, _: ENode.OpApply) => (b, a) + case (_: ENode.Constructor, _) => (a, b) + case (_, _: ENode.Constructor) => (b, a) + case (_: ENode.Select, _) => (a, b) + case (_, _: ENode.Select) => (b, a) + case (_: ENode.Apply, _) => (a, b) + case (_, _: ENode.Apply) => (b, a) + case (_: ENode.TypeApply, _) => (a, b) + case (_, _: ENode.TypeApply) => (b, a) + case (_: ENode.Atom, _) => (a, b) + case (_, _: ENode.Atom) => (b, a) + case _ => (a, b) + + def repair(): Unit = + var i = 0 + trace(s"repair (queue: ${worklist.map(show).mkString(", ")})"): + while !worklist.isEmpty do + val head = worklist.dequeue() + val headRepr = representant(head) + val headCanonical = canonicalize(head, deep = false) + if headRepr ne headCanonical then + trace(s"repair ${show(headCanonical)}, ${show(headRepr)}"): + merge(headCanonical, headRepr) + i += 1 + if i > 100 then + throw new RuntimeException("EGraph.repair: too many iterations, possible infinite loop") + + assertInvariants() + + def assertInvariants(): Unit = + if checksEnabled then + assert(worklist.isEmpty, "Worklist is not empty") + + // Check that all nodes in the index are canonical + for (node, node2) <- index.iterator do + assert(node eq node2, s"Key and value in index are not equal: $node ne $node2") + + val repr = representant(node) + assertCanonical(repr) + + def usages(node: ENode): mutable.Set[ENode] = + usedBy.getOrElse(node, mutable.Set.empty) + + node match + case ENode.Atom(tp) => () + case ENode.Constructor(sym) => () + case ENode.Select(qual, member) => + index.contains(qual) && usages(qual).contains(node) + case ENode.Apply(fn, args) => + index.contains(fn) && usages(fn).contains(node) + args.forall(arg => index.contains(arg) && usages(arg).contains(node)) + case ENode.OpApply(op, args) => + args.forall(arg => index.contains(arg) && usages(arg).contains(node)) + case ENode.TypeApply(fn, args) => + index.contains(fn) && usages(fn).contains(node) + case ENode.Lambda(paramTps, retTp, body) => + index.contains(body) && usages(body).contains(node) + + for (node, repr) <- representantOf.iterator do + assert(index.contains(node), s"Node $node is not in the index") + + for (child, parents) <- usedBy.iterator do + assertCanonical(child) + + // ----------------------------------- + // Canonicalization + // ----------------------------------- + + def canonicalize(node: ENode, deep: Boolean = true): ENode = + def recur(node: ENode): ENode = + if deep then canonicalize(node, deep) else representant(node) + trace(s"canonicalize ${show(node)}"): + representant(unique( + node match + case ENode.Atom(tp) => + node + case ENode.Constructor(sym) => + node + case ENode.Select(qual, member) => + normalizeSelect(recur(qual), member) + case ENode.Apply(fn, args) => + ENode.Apply(recur(fn), args.map(recur)) + case ENode.OpApply(op, args) => + normalizeOp(op, args.map(recur)) + case ENode.TypeApply(fn, args) => + ENode.TypeApply(recur(fn), args) + case ENode.Lambda(paramTps, retTp, body) => + ENode.Lambda(paramTps, retTp, recur(body)) + )) + + private def normalizeSelect(qual: ENode, member: Symbol): ENode = + getAppliedConstructor(qual) match + case Some(constr) => + val memberIndex = constr.fields.indexOf(member) + if memberIndex >= 0 then + val args = getTermArguments(qual) + assert(args.size == constr.fields.size) + args(memberIndex) + else + ENode.Select(qual, member) + case None => + ENode.Select(qual, member) + + private def getAppliedConstructor(node: ENode): Option[ENode.Constructor] = + node match + case ENode.Apply(fn, args) => getAppliedConstructor(fn) + case ENode.TypeApply(fn, args) => getAppliedConstructor(fn) + case node: ENode.Constructor => Some(node) + case _ => None + + private def getTermArguments(node: ENode): List[ENode] = + node match + case ENode.Apply(fn, args) => getTermArguments(fn) ::: args + case ENode.TypeApply(fn, args) => getTermArguments(fn) + case _ => Nil + + private def normalizeOp(op: ENode.Op, args: List[ENode]): ENode = + val res = op match + case Op.Equal => + assert(args.size == 2, s"Expected 2 arguments for equality, got $args") + if args(0) eq args(1) then + trueNode + else ENode.OpApply(op, args.sortBy(idOf.apply)) + case Op.And => + assert(args.size == 2, s"Expected 2 arguments for conjunction, got $args") + if (args(0) eq falseNode) || (args(1) eq falseNode) then falseNode + else if args(0) eq trueNode then args(1) + else if args(1) eq trueNode then args(0) + else ENode.OpApply(op, args) + case Op.Or => + assert(args.size == 2, s"Expected 2 arguments for disjunction, got $args") + if (args(0) eq trueNode) || (args(1) eq trueNode) then trueNode + else if args(0) eq falseNode then args(1) + else if args(1) eq falseNode then args(0) + else ENode.OpApply(op, args) + case Op.IntSum => + val (const, nonConsts) = decomposeIntSum(args) + makeIntSum(const, nonConsts) + case Op.IntMinus => + assert(args.size == 2, s"Expected 2 arguments for subtraction, got $args") + // Rewrite a - b as a + (-1) * b + val lhs = args(0) + val rhs = args(1) + val negativeRhs = unique(normalizeOp(Op.IntProduct, List(minusOneIntNode, rhs))) + normalizeOp(Op.IntSum, List(lhs, negativeRhs)) + case Op.IntProduct => + val (consts, nonConsts) = decomposeIntProduct(args) + makeIntProduct(consts, nonConsts) + case Op.IntLessThan => constFoldBinaryOp[Int, Boolean](op, args, _ < _) + case Op.IntLessEqual => constFoldBinaryOp[Int, Boolean](op, args, _ <= _) + case Op.IntGreaterThan => constFoldBinaryOp[Int, Boolean](op, args, _ > _) + case Op.IntGreaterEqual => constFoldBinaryOp[Int, Boolean](op, args, _ >= _) + case _ => + ENode.OpApply(op, args) + res + + private def constFoldBinaryOp[T: ClassTag, S](op: ENode.Op, args: List[ENode], fn: (T, T) => S): ENode = + args match + case List(ENode.Atom(ConstantType(Constant(c1: T))), ENode.Atom(ConstantType(Constant(c2: T)))) => + constant(fn(c1, c2)) + case _ => + ENode.OpApply(op, args) + + private def decomposeIntProduct(args: List[ENode]): (Int, List[ENode]) = + val factors = + args.flatMap: + case ENode.OpApply(Op.IntProduct, innerFactors) => innerFactors + case arg => List(arg) + val (consts, nonConsts) = + factors.partitionMap: + case ENode.Atom(ConstantType(Constant(c: Int))) => Left(c) + case factor => Right(factor) + (consts.product, nonConsts.sortBy(idOf.apply)) + + private def makeIntProduct(const: Int, nonConsts: List[ENode]): ENode = + if const == 0 then + zeroIntNode + else if const == 1 then + if nonConsts.isEmpty then oneIntNode + else if nonConsts.size == 1 then nonConsts.head + else ENode.OpApply(Op.IntProduct, nonConsts) + else + val constNode = constant(const) + nonConsts match + case Nil => + constNode + //case List(ENode.OpApply(Op.IntSum, summands)) => + // ENode.OpApply( + // Op.IntSum, + // summands.map(summand => unique(makeIntProduct(const, List(summand)))) + // ) + case _ => + ENode.OpApply(Op.IntProduct, constNode :: nonConsts) + + private def decomposeIntSum(args: List[ENode]): (Int, List[ENode]) = + val summands: List[ENode] = + args.flatMap: + case ENode.OpApply(Op.IntSum, innerSummands) => innerSummands + case arg => List(arg) + val decomposed: List[(Int, List[ENode])] = + summands.map: + case ENode.OpApply(Op.IntProduct, args) => + args match + case ENode.Atom(ConstantType(Constant(const: Int))) :: nonConsts => (const, nonConsts) + case nonConsts => (1, nonConsts) + case ENode.Atom(ConstantType(Constant(const: Int))) => (const, Nil) + case other => (1, List(other)) + val grouped = decomposed.groupMapReduce(_._2)(_._1)(_ + _) + val const = grouped.getOrElse(Nil, 0) + val nonConsts = + grouped + .toList + .filter((nonConsts, const) => const != 0 && !nonConsts.isEmpty) + .sortBy((nonConsts, const) => idOf(nonConsts.head)) + .map((nonConsts, const) => unique(makeIntProduct(const, nonConsts))) + (const, nonConsts) + + private def makeIntSum(const: Int, nonConsts: List[ENode]): ENode = + if const == 0 then + if nonConsts.isEmpty then zeroIntNode + else if nonConsts.size == 1 then nonConsts.head + else ENode.OpApply(Op.IntSum, nonConsts) + else + val constNode = constant(const) + if nonConsts.isEmpty then constNode + else ENode.OpApply(Op.IntSum, constNode :: nonConsts) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala b/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala new file mode 100644 index 000000000000..7b88b6e2b9e6 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala @@ -0,0 +1,638 @@ +package dotty.tools.dotc.qualified_types + +import scala.collection.mutable.ListBuffer + +import dotty.tools.dotc.ast.{tpd, untpd} +import dotty.tools.dotc.ast.tpd.TreeOps +import dotty.tools.dotc.config.Printers +import dotty.tools.dotc.config.Settings.Setting.value +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Contexts.ctx +import dotty.tools.dotc.core.Decorators.i +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.Flags.EmptyFlags +import dotty.tools.dotc.core.Hashable.Binders +import dotty.tools.dotc.core.Names.{termName, Name} +import dotty.tools.dotc.core.Names.Designator +import dotty.tools.dotc.core.StdNames.nme +import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol} +import dotty.tools.dotc.core.Symbols.Symbol +import dotty.tools.dotc.core.Types.{ + AndType, + AppliedType, + CachedProxyType, + ClassInfo, + ConstantType, + LambdaType, + MethodType, + NamedType, + NoPrefix, + ParamRef, + SingletonType, + SkolemType, + TermParamRef, + TermRef, + ThisType, + Type, + TypeMap, + TypeProxy, + TypeRef, + TypeVar, + ValueType +} +import dotty.tools.dotc.parsing +import dotty.tools.dotc.printing.{Printer, Showable} +import dotty.tools.dotc.printing.GlobalPrec +import dotty.tools.dotc.printing.Texts.{stringToText, Text} +import dotty.tools.dotc.qualified_types.ENode.Op +import dotty.tools.dotc.reporting.trace +import dotty.tools.dotc.transform.TreeExtractors.{BinaryOp, UnaryOp} +import dotty.tools.dotc.util.Spans.Span + +enum ENode extends Showable: + import ENode.* + + case Atom(tp: SingletonType) + case Constructor(constr: Symbol)(val fields: List[Symbol]) + case Select(qual: ENode, member: Symbol) + case Apply(fn: ENode, args: List[ENode]) + case OpApply(fn: ENode.Op, args: List[ENode]) + case TypeApply(fn: ENode, args: List[Type]) + case Lambda(paramTps: List[Type], retTp: Type, body: ENode) + + require( + this match + case Constructor(constr) => + constr.lastKnownDenotation.isConstructor + case Lambda(paramTps, retTp, body) => + paramTps.zipWithIndex.forall: (tp, index) => + tp match + case ENodeParamRef(i, _) => i < index + case _ => true + case _ => true + ) + + def prettyString(printFullPaths: Boolean = false): String = + + def rec(n: ENode): String = + n.prettyString(printFullPaths) + + def printTp(tp: Type): String = + tp match + case tp: NamedType => + val prefixString = if isEmptyPrefix(tp.prefix) then "" else printTp(tp.prefix) + "." + prefixString + printDesignator(tp.designator) // + s"#${System.identityHashCode(tp).toHexString}" + case tp: ConstantType => + tp.value.value.toString // + s"#${System.identityHashCode(tp).toHexString}" + case tp: SkolemType => + "(?" + tp.hashCode + ": " + printTp(tp.info) + ")" + case tp: ThisType => + printTp(tp.tref) + ".this" + case tp: TypeVar => + tp.origin.paramName.toString() + case tp: ENodeParamRef => + s"arg${tp.index}" + case tp: AppliedType => + val argsString = tp.args.map(printTp).mkString(", ") + s"${printTp(tp.tycon)}[$argsString]" + case _ => + tp.toString + + def printDesignator(d: Designator): String = + d match + case d: Symbol => d.lastKnownDenotation.name.toString + case _ => d.toString + + this match + case Atom(tp) => + printTp(tp) + case Constructor(constr) => + s"new ${printDesignator(constr.lastKnownDenotation.owner)}" + case Select(qual, member) => + s"${rec(qual)}.${printDesignator(member)}" + case Apply(fn, args) => + s"${rec(fn)}(${args.map(rec).mkString(", ")})" + case OpApply(op, args) => + s"(${args.map(rec).mkString(" " + op.operatorName().toString() + " ")})" + case TypeApply(fn, args) => + s"${rec(fn)}[${args.map(printTp).mkString(", ")}]" + case Lambda(paramTps, retTp, body) => + val paramsString = paramTps.map(p => "_: " + printTp(p)).mkString(", ") + s"($paramsString): ${printTp(retTp)} => ${rec(body)}" + + + override def toText(p: Printer): Text = toText(p, false) + + def toText(p: Printer, printAddresses: Boolean): Text = + given Context = p.printerContext + + def withAddress(obj: Any, text: Text): Text = + if printAddresses then + "<" ~ text ~ s"#${System.identityHashCode(obj).toHexString}" ~ ">" + else + text + + def listToText[T](xs: List[T], fn: T => Text, sep: Text): Text = + xs.map(fn).reduceLeftOption(_ ~ sep ~ _).getOrElse("") + + withAddress( + this, + this match + case Atom(tp) => + p.toTextRef(tp) + case Constructor(constr) => + "new" ~ p.toText(constr.lastKnownDenotation.owner) + case Select(qual, member) => + qual.toText(p) ~ "." ~ p.toText(member.name) + case Apply(fn, args) => + fn.toText(p) ~ "(" ~ listToText(args, arg => p.atPrec(GlobalPrec)(arg.toText(p)), ", ") ~ ")" + case OpApply(op, args) => + assert(args.nonEmpty) + op match + // All operators with arity >= 2 + case Op.IntSum | Op.IntMinus | Op.IntProduct | + Op.IntLessThan | Op.IntLessEqual | Op.IntGreaterThan | Op.IntGreaterEqual | + Op.LongSum | Op.LongMinus | Op.LongProduct | + Op.And | Op.Or | Op.Equal | Op.NotEqual => + val opPrec = parsing.precedence(op.operatorName()) + val isRightAssoc = false + val leftPrec = if isRightAssoc then opPrec + 1 else opPrec + val rightPrec = if !isRightAssoc then opPrec + 1 else opPrec + p.changePrec(opPrec): + args.map(_.toText(p)).reduceLeft: (l, r) => + p.atPrec(leftPrec)(l) ~ " " ~ p.toText(op.operatorName()) ~ " " ~ p.atPrec(rightPrec)(r) + // Unary operators + case _ => + assert(args.length == 1) + val opPrec = parsing.precedence(op.operatorName()) + p.changePrec(opPrec): + p.toText(op.operatorName()) ~ p.atPrec(opPrec + 1)(args.head.toText(p)) + case TypeApply(fn, args) => + fn.toText(p) ~ "[" ~ listToText(args, p.toText, ",") ~ "]" + case Lambda(paramTps, retTp, body) => + val paramsText = listToText(paramTps, "_: " ~ p.toText(_), ", ") + "(" ~ paramsText ~ ")" ~ " => " ~ p.atPrec(GlobalPrec)(body.toText(p)) + ) + + def showNoBreak(using Context): String = + toText(ctx.printer).mkString() + + def dotId() = + "n" + System.identityHashCode(this).toHexString.substring(1) + + def toDot()(using _ctx: Context): String = + given Context = _ctx.withoutColors + val id = dotId() + val fields: List[ENode | String] = + this match + case Atom(tp) => this.showNoBreak :: Nil + case Constructor(constr) => this.showNoBreak :: Nil + case Select(qual, member) => qual :: member.name.show :: Nil + case Apply(fn, args) => fn :: args + case OpApply(op, args) => op.operatorName().toString() :: args + case TypeApply(fn, args) => fn :: args.map(_.show) + case Lambda(paramTps, retTp, body) => + val paramsString = paramTps.map(p => "_:" + p.show).mkString(", ") + "(" + paramsString + ") => " + retTp.show :: body :: Nil + val fieldStrings = + fields.zipWithIndex.map: (field, i) => + field match + case child: ENode => s"
"
+ case str: String => str.replace("<", "\\<").replace(">", "\\>")
+ val nodeString = s"$id [label=\"${fieldStrings.mkString("|")}\"];\n"
+ val edgesString =
+ fields.zipWithIndex.map: (field, i) =>
+ field match
+ case child: ENode => s"$id:p$i -> ${child.dotId()};\n"
+ case _ => ""
+ nodeString + edgesString.mkString
+
+ def mapTypes(f: Type => Type)(using Context): ENode =
+ this match
+ case Atom(tp) =>
+ val mappedTp = f(tp)
+ if mappedTp eq tp then
+ this
+ else
+ mappedTp match
+ case mappedTp: SingletonType => Atom(mappedTp)
+ case _ => Atom(SkolemType(mappedTp))
+ case Constructor(constr) =>
+ this
+ case node @ Select(qual, member) =>
+ node.derived(qual.mapTypes(f), member)
+ case node @ Apply(fn, args) =>
+ node.derived(fn.mapTypes(f), args.mapConserve(_.mapTypes(f)))
+ case node @ OpApply(op, args) =>
+ node.derived(op, args.mapConserve(_.mapTypes(f)))
+ case node @ TypeApply(fn, args) =>
+ node.derived(fn.mapTypes(f), args.mapConserve(f))
+ case node @ Lambda(paramTps, retTp, body) =>
+ node.derived(paramTps.mapConserve(f), f(retTp), body.mapTypes(f))
+
+ def foreachType(f: Type => Unit)(using Context): Unit =
+ this match
+ case Atom(tp) => f(tp)
+ case Constructor(_) => ()
+ case Select(qual, _) => qual.foreachType(f)
+ case Apply(fn, args) =>
+ fn.foreachType(f)
+ args.foreach(_.foreachType(f))
+ case OpApply(_, args) => args.foreach(_.foreachType(f))
+ case TypeApply(fn, args) =>
+ fn.foreachType(f)
+ args.foreach(f)
+ case Lambda(paramTps, retTp, body) =>
+ paramTps.foreach(f)
+ f(retTp)
+ body.foreachType(f)
+
+ def normalizeTypes()(using Context): ENode =
+ mapTypes(NormalizeMap())
+
+ private class NormalizeMap(using Context) extends TypeMap:
+ def apply(tp: Type): Type =
+ tp match
+ case tp: TypeVar if tp.isPermanentlyInstantiated =>
+ apply(tp.permanentInst)
+ case tp: NamedType =>
+ val dealiased = tp.dealias
+ if dealiased ne tp then
+ apply(dealiased)
+ else if tp.symbol.isStatic then
+ if tp.isInstanceOf[TermRef] then tp.symbol.termRef
+ else tp.symbol.typeRef
+ else
+ derivedSelect(tp, apply(tp.prefix))
+ case _ =>
+ mapOver(tp)
+
+ def substEParamRefs(from: Int, to: List[Type])(using Context): ENode =
+ this match
+ case Atom(tp) =>
+ mapTypes(SubstEParamsMap(from, to))
+ case Constructor(_) =>
+ this
+ case node @ Select(qual, member) =>
+ node.derived(qual.substEParamRefs(from, to), member)
+ case node @ Apply(fn, args) =>
+ node.derived(fn.substEParamRefs(from, to), args.mapConserve(_.substEParamRefs(from, to)))
+ case node @ OpApply(op, args) =>
+ node.derived(op, args.mapConserve(_.substEParamRefs(from, to)))
+ case node @ TypeApply(fn, args) =>
+ node.derived(fn.substEParamRefs(from, to), args.mapConserve(SubstEParamsMap(from, to)))
+ case node @ Lambda(paramTps, retTp, body) =>
+ node.derived(paramTps.mapConserve(SubstEParamsMap(from, to)), SubstEParamsMap(from, to)(retTp), body.substEParamRefs(from + paramTps.length, to))
+
+ private class SubstEParamsMap(from: Int, to: List[Type])(using Context) extends TypeMap:
+ override def apply(tp: Type): Type =
+ tp match
+ case ENodeParamRef(i, _) if i >= from && i < from + to.length => to(i - from)
+ case _ => mapOver(tp)
+
+ def foreach(f: ENode => Unit): Unit =
+ f(this)
+ this match
+ case Atom(_) => ()
+ case Constructor(_) => ()
+ case Select(qual, _) =>
+ qual.foreach(f)
+ case Apply(fn, args) =>
+ fn.foreach(f)
+ args.foreach(_.foreach(f))
+ case OpApply(_, args) =>
+ args.foreach(_.foreach(f))
+ case TypeApply(fn, args) =>
+ fn.foreach(f)
+ case Lambda(_, _, body) =>
+ body.foreach(f)
+
+ def contains(that: ENode): Boolean =
+ var found = false
+ foreach: node =>
+ if node eq that then
+ found = true
+ found
+
+ // -----------------------------------
+ // Conversion from E-Nodes to Trees
+ // -----------------------------------
+
+ def toTree(paramRefs: List[Type] = Nil)(using Context): tpd.Tree =
+ def mapType(tp: Type): Type = SubstEParamsMap(0, paramRefs)(tp)
+
+ trace(i"ENode.toTree $this, paramRefs: $paramRefs", Printers.qualifiedTypes):
+ this match
+ case Atom(tp) =>
+ mapType(tp) match
+ case tp1: TermParamRef => untpd.Ident(tp1.paramName).withType(tp1)
+ case tp1 => tpd.singleton(tp1)
+ case Constructor(sym) =>
+ val tycon = sym.owner.asClass.classDenot.classInfo.selfType
+ tpd.New(tycon).select(TermRef(tycon, sym))
+ case Select(qual, member) =>
+ qual.toTree(paramRefs).select(member)
+ case Apply(fn, args) =>
+ tpd.Apply(fn.toTree(paramRefs), args.map(_.toTree(paramRefs)))
+ case OpApply(op, args) =>
+ def unaryOp(symbol: Symbol): tpd.Tree =
+ require(args.length == 1)
+ args(0).toTree(paramRefs).select(symbol).appliedToNone
+ def binaryOp(symbol: Symbol): tpd.Tree =
+ require(args.length == 2)
+ args(0).toTree(paramRefs).select(symbol).appliedTo(args(1).toTree(paramRefs))
+ op match
+ case Op.IntSum =>
+ args.map(_.toTree(paramRefs)).reduceLeft(_.select(defn.Int_+).appliedTo(_))
+ case Op.IntMinus =>
+ binaryOp(defn.Int_-)
+ case Op.IntProduct =>
+ args.map(_.toTree(paramRefs)).reduceLeft(_.select(defn.Int_*).appliedTo(_))
+ case Op.LongSum =>
+ ???
+ case Op.LongMinus =>
+ ???
+ case Op.LongProduct =>
+ ???
+ case Op.Equal =>
+ args(0).toTree(paramRefs).equal(args(1).toTree(paramRefs))
+ case Op.NotEqual =>
+ val lhs = args(0).toTree(paramRefs)
+ val rhs = args(1).toTree(paramRefs)
+ tpd.applyOverloaded(lhs, nme.NE, rhs :: Nil, Nil, defn.BooleanType)
+ case Op.Not => unaryOp(defn.Boolean_!)
+ case Op.And => binaryOp(defn.Boolean_&&)
+ case Op.Or => binaryOp(defn.Boolean_||)
+ case Op.IntLessThan => binaryOp(defn.Int_<)
+ case Op.IntLessEqual => binaryOp(defn.Int_<=)
+ case Op.IntGreaterThan => binaryOp(defn.Int_>)
+ case Op.IntGreaterEqual => binaryOp(defn.Int_>=)
+ case TypeApply(fn, args) =>
+ tpd.TypeApply(fn.toTree(paramRefs), args.map(tp => tpd.TypeTree(mapType(tp), false)))
+ case Lambda(paramTps, retTp, body) =>
+ val myParamNames = paramTps.zipWithIndex.map((tp, i) => termName("param" + (paramRefs.size + i)))
+ def computeParamTypes(mt: MethodType) =
+ val reversedParamRefs = mt.paramRefs.reverse
+ paramTps.zipWithIndex.map((tp, i) => SubstEParamsMap(0, reversedParamRefs.take(i) ::: paramRefs)(tp))
+ val mt = MethodType(myParamNames)(computeParamTypes, _ => retTp)
+ tpd.Lambda(mt, myParamRefTrees =>
+ val myParamRefs = myParamRefTrees.map(_.tpe).reverse
+ body.toTree(myParamRefs ::: paramRefs)
+ )
+
+object ENode:
+ private def isEmptyPrefix(tp: Type): Boolean =
+ tp match
+ case tp: NoPrefix.type =>
+ true
+ case tp: ThisType =>
+ tp.tref.designator match
+ case d: Symbol => d.lastKnownDenotation.name.toTermName == nme.EMPTY_PACKAGE
+ case _ => false
+ case _ => false
+
+
+ enum Op:
+ case IntSum
+ case IntMinus
+ case IntProduct
+ case LongSum
+ case LongMinus
+ case LongProduct
+ case Equal
+ case NotEqual
+ case Not
+ case And
+ case Or
+ case IntLessThan
+ case IntLessEqual
+ case IntGreaterThan
+ case IntGreaterEqual
+
+ def operatorName(): Name =
+ this match
+ case IntSum => nme.Plus
+ case IntMinus => nme.Minus
+ case IntProduct => nme.Times
+ case LongSum => nme.Plus
+ case LongMinus => nme.Minus
+ case LongProduct => nme.Times
+ case Equal => nme.Equals
+ case NotEqual => nme.NotEquals
+ case Not => nme.Not
+ case And => nme.And
+ case Or => nme.Or
+ case IntLessThan => nme.Le
+ case IntLessEqual => nme.Lt
+ case IntGreaterThan => nme.Gt
+ case IntGreaterEqual => nme.Ge
+
+ // -----------------------------------
+ // Conversion from Trees to E-Nodes
+ // -----------------------------------
+
+ def fromTree(
+ tree: tpd.Tree,
+ paramSyms: List[Symbol] = Nil,
+ paramTps: List[Type] = Nil
+ )(using Context): Option[ENode] =
+ val d = defn // Need a stable path to match on `defn` members
+
+ def binaryOpNode(op: ENode.Op, lhs: tpd.Tree, rhs: tpd.Tree): Option[ENode] =
+ for
+ lhsNode <- fromTree(lhs, paramSyms, paramTps)
+ rhsNode <- fromTree(rhs, paramSyms, paramTps)
+ yield OpApply(op, List(lhsNode, rhsNode))
+
+ def unaryOpNode(op: ENode.Op, arg: tpd.Tree): Option[ENode] =
+ for argNode <- fromTree(arg, paramSyms, paramTps) yield OpApply(op, List(argNode))
+
+ def isValidEqual(sym: Symbol, lhs: tpd.Tree, rhs: tpd.Tree): Boolean =
+ def lhsClass = lhs.tpe.classSymbol
+ sym == defn.Int_==
+ || sym == defn.Boolean_==
+ || sym == defn.Any_== && lhsClass == defn.StringClass
+ || sym.name == nme.EQ && lhsClass.exists && hasCaseClassEquals(lhsClass)
+
+ trace(s"ENode.fromTree $tree", Printers.qualifiedTypes):
+ tree match
+ case tpd.Literal(_) | tpd.Ident(_) | tpd.This(_)
+ if tree.tpe.isInstanceOf[SingletonType] && tpd.isIdempotentExpr(tree) =>
+ Some(Atom(substParamRefs(tree.tpe, paramSyms, paramTps).asInstanceOf[SingletonType]))
+ case tpd.Select(tpd.New(_), nme.CONSTRUCTOR) =>
+ constructorNode(tree.symbol)
+ case tree: tpd.Select if isCaseClassApply(tree.symbol) =>
+ constructorNode(tree.symbol.owner.linkedClass.primaryConstructor)
+ case tpd.Select(qual, name) =>
+ for qualNode <- fromTree(qual, paramSyms, paramTps) yield Select(qualNode, tree.symbol)
+ case BinaryOp(lhs, sym, rhs) if isValidEqual(sym, lhs, rhs) => binaryOpNode(ENode.Op.Equal, lhs, rhs)
+ case BinaryOp(lhs, d.Int_!= | d.Boolean_!=, rhs) => binaryOpNode(ENode.Op.NotEqual, lhs, rhs)
+ case UnaryOp(d.Boolean_!, arg) => unaryOpNode(ENode.Op.Not, arg)
+ case BinaryOp(lhs, d.Boolean_&&, rhs) => binaryOpNode(ENode.Op.And, lhs, rhs)
+ case BinaryOp(lhs, d.Boolean_||, rhs) => binaryOpNode(ENode.Op.Or, lhs, rhs)
+ case BinaryOp(lhs, d.Int_+, rhs) => binaryOpNode(ENode.Op.IntSum, lhs, rhs)
+ case BinaryOp(lhs, d.Int_-, rhs) => binaryOpNode(ENode.Op.IntMinus, lhs, rhs)
+ case BinaryOp(lhs, d.Int_*, rhs) => binaryOpNode(ENode.Op.IntProduct, lhs, rhs)
+ case BinaryOp(lhs, d.Int_<, rhs) => binaryOpNode(ENode.Op.IntLessThan, lhs, rhs)
+ case BinaryOp(lhs, d.Int_<=, rhs) => binaryOpNode(ENode.Op.IntLessEqual, lhs, rhs)
+ case BinaryOp(lhs, d.Int_>, rhs) => binaryOpNode(ENode.Op.IntGreaterThan, lhs, rhs)
+ case BinaryOp(lhs, d.Int_>=, rhs) => binaryOpNode(ENode.Op.IntGreaterEqual, lhs, rhs)
+ case tpd.Apply(fun, args) =>
+ for
+ funNode <- fromTree(fun, paramSyms, paramTps)
+ argsNodes <- args.map(fromTree(_, paramSyms, paramTps)).sequence
+ yield ENode.Apply(funNode, argsNodes)
+ case tpd.TypeApply(fun, args) =>
+ for funNode <- fromTree(fun, paramSyms, paramTps)
+ yield ENode.TypeApply(funNode, args.map(tp => substParamRefs(tp.tpe, paramSyms, paramTps)))
+ case tpd.closureDef(defDef) =>
+ defDef.symbol.info.dealias match
+ case mt: MethodType =>
+ assert(defDef.termParamss.size == 1, "closure is expected to have a single parameter list")
+ var newParamSyms: List[Symbol] = paramSyms
+ var newParamTps: List[Type] = paramTps
+ val myParamSyms: List[Symbol] = defDef.termParamss.head.map(_.symbol)
+ val myParamTps: List[Type] = mt.paramInfos
+ for (myParamSym, myParamTp) <- myParamSyms.zip(myParamTps) do
+ newParamTps = substParamRefs(myParamTp, newParamSyms, newParamTps) :: newParamTps
+ newParamSyms = myParamSym :: newParamSyms
+ val myRetTp = substParamRefs(mt.resType, newParamSyms, newParamTps)
+ for body <- fromTree(defDef.rhs, newParamSyms, newParamTps)
+ yield ENode.Lambda(newParamTps.take(myParamTps.size), myRetTp, body)
+ case _ => None
+ case _ =>
+ None
+
+ private def constructorNode(constr: Symbol)(using Context): Option[ENode.Constructor] =
+ val clazz = constr.owner
+ if hasCaseClassEquals(clazz) then
+ val isPrimaryConstructor = constr.denot.isPrimaryConstructor
+ val fieldsRaw = clazz.denot.asClass.paramAccessors.filter(isPrimaryConstructor && _.isStableMember)
+ val constrParams = constr.paramSymss.flatten.filter(_.isTerm)
+ val fields = constrParams.map(p => fieldsRaw.find(_.name == p.name).getOrElse(NoSymbol))
+ Some(ENode.Constructor(constr)(fields))
+ else
+ None
+
+ private def hasCaseClassEquals(clazz: Symbol)(using Context): Boolean =
+ val equalsMethod = clazz.info.decls.lookup(nme.equals_)
+ val equalsNotOverriden = !equalsMethod.exists || equalsMethod.is(Flags.Synthetic)
+ clazz.isClass && clazz.is(Flags.Case) && equalsNotOverriden
+
+ private def isCaseClassApply(meth: Symbol)(using Context): Boolean =
+ meth.name == nme.apply
+ && meth.flags.is(Flags.Synthetic)
+ && meth.owner.linkedClass.is(Flags.Case)
+
+ def substParamRefs(tp: Type, paramSyms: List[Symbol], paramTps: List[Type])(using Context): Type =
+ trace(i"substParamRefs($tp, $paramSyms, $paramTps)", Printers.qualifiedTypes):
+ tp.subst(paramSyms, paramTps.zipWithIndex.map((tp, i) => ENodeParamRef(i, tp)).toList)
+
+ def selfify(tree: tpd.Tree)(using Context): Option[ENode.Lambda] =
+ trace(i"ENode.selfify $tree", Printers.qualifiedTypes):
+ fromTree(tree) match
+ case Some(treeNode) =>
+ Some(ENode.Lambda(
+ List(tree.tpe),
+ defn.BooleanType,
+ OpApply(ENode.Op.Equal, List(treeNode, ENode.Atom(ENodeParamRef(0, tree.tpe))))
+ ))
+ case None => None
+
+ // -----------------------------------
+ // Assumptions retrieval
+ // -----------------------------------
+
+ def assumptions(node: ENode)(using Context): List[ENode] =
+ trace(i"assumptions($node)", Printers.qualifiedTypes):
+ node match
+ case Atom(tp: SingletonType) => termAssumptions(tp) ++ typeAssumptions(tp)
+ case n: Constructor => Nil
+ case n: Select => assumptions(n.qual)
+ case n: Apply => assumptions(n.fn) ++ n.args.flatMap(assumptions)
+ case n: OpApply => n.args.flatMap(assumptions)
+ case n: TypeApply => assumptions(n.fn)
+ case n: Lambda => Nil
+
+ private def termAssumptions(tp: SingletonType)(using Context): List[ENode] =
+ trace(i"termAssumptions($tp)", Printers.qualifiedTypes):
+ tp match
+ case tp: TermRef =>
+ tp.symbol.info match
+ case QualifiedType(_, _) => Nil
+ case _ =>
+ tp.symbol.defTree match
+ case valDef: tpd.ValDef if !valDef.rhs.isEmpty && !valDef.symbol.is(Flags.Lazy) =>
+ fromTree(valDef.rhs) match
+ case Some(treeNode) => OpApply(ENode.Op.Equal, List(treeNode, Atom(tp))) :: assumptions(treeNode)
+ case None => Nil
+ case _ => Nil
+ case _ => Nil
+
+ private def typeAssumptions(rootTp: SingletonType)(using Context): List[ENode] =
+ def rec(tp: Type): List[ENode] =
+ tp match
+ case QualifiedType(parent, qualifier) => qualifier.body.substEParamRefs(0, List(rootTp)) :: assumptions(qualifier.body) ::: rec(parent)
+ case tp: SingletonType if tp ne rootTp => List(OpApply(ENode.Op.Equal, List(Atom(tp), Atom(rootTp))))
+ case tp: TypeProxy => rec(tp.underlying)
+ case AndType(tp1, tp2) => rec(tp1) ++ rec(tp2)
+ case _ => Nil
+ trace(i"typeAssumptions($rootTp)", Printers.qualifiedTypes):
+ rec(rootTp)
+
+ // -----------------------------------
+ // Utils
+ // -----------------------------------
+
+ extension (n: Atom)
+ def derived(tp: SingletonType): ENode.Atom =
+ if n.tp eq tp then n
+ else ENode.Atom(tp)
+
+ extension (n: Constructor)
+ def derived(constr: Symbol): ENode.Constructor =
+ if n.constr eq constr then n
+ else ENode.Constructor(constr)(n.fields)
+
+ extension (n: Select)
+ def derived(qual: ENode, member: Symbol): ENode.Select =
+ if (n.qual eq qual) && (n.member eq member) then n
+ else ENode.Select(qual, member)
+
+ extension (n: Apply)
+ def derived(fn: ENode, args: List[ENode]): ENode.Apply =
+ if (n.fn eq fn) && (n.args eq args) then n
+ else ENode.Apply(fn, args)
+
+ extension (n: OpApply)
+ def derived(op: ENode.Op, args: List[ENode]): ENode.OpApply =
+ if (n.fn eq op) && (n.args eq args) then n
+ else ENode.OpApply(op, args)
+
+ extension (n: TypeApply)
+ def derived(fn: ENode, args: List[Type]): ENode.TypeApply =
+ if (n.fn eq fn) && (n.args eq args) then n
+ else ENode.TypeApply(fn, args)
+
+ extension (n: Lambda)
+ def derived(paramTps: List[Type], retTp: Type, body: ENode): ENode.Lambda =
+ if (n.paramTps eq paramTps) && (n.retTp eq retTp) && (n.body eq body) then n
+ else ENode.Lambda(paramTps, retTp, body)
+
+ // -----------------------------------
+ // Utils
+ // -----------------------------------
+
+ extension [T](xs: List[Option[T]])
+ private def sequence: Option[List[T]] =
+ var result = List.newBuilder[T]
+ var current = xs
+ while current.nonEmpty do
+ current.head match
+ case Some(x) =>
+ result += x
+ current = current.tail
+ case None =>
+ return None
+ Some(result.result())
diff --git a/compiler/src/dotty/tools/dotc/qualified_types/ENodeParamRef.scala b/compiler/src/dotty/tools/dotc/qualified_types/ENodeParamRef.scala
new file mode 100644
index 000000000000..fdf4513d7d97
--- /dev/null
+++ b/compiler/src/dotty/tools/dotc/qualified_types/ENodeParamRef.scala
@@ -0,0 +1,23 @@
+package dotty.tools.dotc.qualified_types
+import dotty.tools.dotc.core.Types.{
+ SingletonType,
+ CachedProxyType,
+ Type
+}
+import dotty.tools.dotc.core.Contexts.Context
+import dotty.tools.dotc.core.Hashable.Binders
+
+
+/** Reference to the argument of an [[ENode.Lambda]].
+ *
+ * @param index
+ * Debruijn index of the argument, starting from 0
+ * @param underyling
+ * Underlying type of the argument
+ */
+final case class ENodeParamRef(index: Int, underlying: Type) extends CachedProxyType, SingletonType:
+ override def underlying(using Context): Type = underlying
+ override def computeHash(bs: Binders): Int = doHash(bs, index, underlying)
+ def derivedENodeParamRef(index: Int, underlying: Type): ENodeParamRef =
+ if index == this.index && (underlying eq this.underlying) then this
+ else ENodeParamRef(index, underlying)
diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedAnnotation.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedAnnotation.scala
new file mode 100644
index 000000000000..83738a188ced
--- /dev/null
+++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedAnnotation.scala
@@ -0,0 +1,39 @@
+package dotty.tools.dotc.qualified_types
+
+import dotty.tools.dotc.ast.tpd.Tree
+import dotty.tools.dotc.core.Annotations.Annotation
+import dotty.tools.dotc.core.Contexts.{ctx, Context}
+import dotty.tools.dotc.core.Decorators.i
+import dotty.tools.dotc.core.Symbols.defn
+import dotty.tools.dotc.core.Types.{TermLambda, TermParamRef, Type, ConstantType, TypeMap}
+import dotty.tools.dotc.printing.Printer
+import dotty.tools.dotc.printing.Texts.Text
+import dotty.tools.dotc.printing.Texts.stringToText
+import dotty.tools.dotc.core.Constants.Constant
+import dotty.tools.dotc.report
+
+case class QualifiedAnnotation(qualifier: ENode.Lambda) extends Annotation:
+
+ override def tree(using Context): Tree = qualifier.toTree()
+
+ override def symbol(using Context) = defn.QualifiedAnnot
+
+ override def derivedAnnotation(tree: Tree)(using Context): Annotation = ???
+
+ private def derivedAnnotation(qualifier: ENode.Lambda)(using Context): Annotation =
+ if qualifier eq this.qualifier then this
+ else QualifiedAnnotation(qualifier)
+
+ override def toText(printer: Printer): Text =
+ "with " ~ qualifier.body.toText(printer)
+
+ override def mapWith(tm: TypeMap)(using Context): Annotation =
+ derivedAnnotation(qualifier.mapTypes(tm).asInstanceOf[ENode.Lambda])
+
+ override def refersToParamOf(tl: TermLambda)(using Context): Boolean =
+ var res = false
+ qualifier.foreachType: tp =>
+ tp.stripped match
+ case TermParamRef(tl1, _) if tl eq tl1 => res = true
+ case _ => ()
+ res
diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedType.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedType.scala
new file mode 100644
index 000000000000..22394db100f2
--- /dev/null
+++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedType.scala
@@ -0,0 +1,41 @@
+package dotty.tools.dotc.qualified_types
+
+import dotty.tools.dotc.ast.tpd
+import dotty.tools.dotc.core.Annotations.Annotation
+import dotty.tools.dotc.core.Contexts.{ctx, Context}
+import dotty.tools.dotc.core.Types.{AnnotatedType, Type, ErrorType}
+import dotty.tools.dotc.core.Decorators.em
+import dotty.tools.dotc.typer.ErrorReporting.errorType
+
+/** A qualified type is internally represented as a type annotated with a
+ * `@qualified` annotation.
+ */
+object QualifiedType:
+ /** Extractor for qualified types.
+ *
+ * @param tp
+ * the type to deconstruct
+ * @return
+ * a pair containing the parent type and the qualifier tree (a lambda) on
+ * success, [[None]] otherwise
+ */
+ def unapply(tp: Type)(using Context): Option[(Type, ENode.Lambda)] =
+ tp match
+ case AnnotatedType(parent, QualifiedAnnotation(qualifier)) =>
+ Some((parent, qualifier))
+ case _ =>
+ None
+
+ def apply(parent: Type, qualifier: ENode.Lambda)(using Context): Type =
+ AnnotatedType(parent, QualifiedAnnotation(qualifier))
+
+ def apply(parent: Type, annot: Annotation)(using Context): Type =
+ annot match
+ case annot: QualifiedAnnotation => AnnotatedType(parent, annot)
+ case _ => apply(parent, annot.arguments(0))
+
+ def apply(parent: Type, annotTree: tpd.Tree)(using Context): Type =
+ val arg = tpd.allTermArguments(annotTree)(0)
+ ENode.fromTree(arg) match
+ case Some(qualifier: ENode.Lambda) => apply(parent, qualifier)
+ case _ => errorType(em"Invalid qualifier: $arg", annotTree.srcPos)
diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala
new file mode 100644
index 000000000000..dec4e3f02af0
--- /dev/null
+++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala
@@ -0,0 +1,133 @@
+package dotty.tools.dotc.qualified_types
+
+import dotty.tools.dotc.ast.tpd
+import dotty.tools.dotc.ast.tpd.{
+ Apply,
+ Block,
+ EmptyTree,
+ Ident,
+ If,
+ Lambda,
+ Literal,
+ New,
+ Select,
+ SeqLiteral,
+ This,
+ Throw,
+ Tree,
+ TypeApply,
+ Typed,
+ given
+}
+import dotty.tools.dotc.config.Printers
+import dotty.tools.dotc.core.Atoms
+import dotty.tools.dotc.core.Constants.Constant
+import dotty.tools.dotc.core.Contexts.{ctx, Context}
+import dotty.tools.dotc.core.Decorators.{em, i, toTermName}
+import dotty.tools.dotc.core.StdNames.nme
+import dotty.tools.dotc.core.Symbols.{defn, Symbol}
+import dotty.tools.dotc.core.Types.{
+ AndType,
+ ConstantType,
+ ErrorType,
+ MethodType,
+ OrType,
+ SkolemType,
+ TermRef,
+ Type,
+ TypeProxy
+}
+import dotty.tools.dotc.util.SrcPos
+import dotty.tools.dotc.report
+import dotty.tools.dotc.reporting.trace
+
+object QualifiedTypes:
+ /** Does the type `tp1` imply the qualifier `qualifier2`?
+ *
+ * Used by [[dotty.tools.dotc.core.TypeComparer]] to compare qualified types.
+ *
+ * Note: the logic here is similar to [[Type#derivesAnnotWith]] but
+ * additionally handle comparisons with [[SingletonType]]s.
+ */
+ def typeImplies(tp1: Type, qualifier2: ENode.Lambda, solver: QualifierSolver)(using Context): Boolean =
+ def trySelfifyType() =
+ val ENode.Lambda(List(paramTp), _, _) = qualifier2: @unchecked
+ ENode.selfify(tpd.singleton(tp1)) match
+ case Some(qualifier1) => solver.implies(qualifier1, qualifier2)
+ case None => false
+ trace(i"typeImplies $tp1 --> ${qualifier2.body}", Printers.qualifiedTypes):
+ tp1 match
+ case QualifiedType(parent1, qualifier1) =>
+ solver.implies(qualifier1, qualifier2)
+ case tp1: TermRef =>
+ def trySelfifyRef() =
+ tp1.underlying match
+ case QualifiedType(_, _) => false
+ case _ => trySelfifyType()
+ typeImplies(tp1.underlying, qualifier2, solver) || trySelfifyRef()
+ case tp1: ConstantType =>
+ trySelfifyType()
+ case tp1: TypeProxy =>
+ typeImplies(tp1.underlying, qualifier2, solver)
+ case AndType(tp11, tp12) =>
+ typeImplies(tp11, qualifier2, solver) || typeImplies(tp12, qualifier2, solver)
+ case OrType(tp11, tp12) =>
+ typeImplies(tp11, qualifier2, solver) && typeImplies(tp12, qualifier2, solver)
+ case _ =>
+ val trueQualifier: ENode.Lambda = ENode.Lambda(
+ List(defn.AnyType),
+ defn.BooleanType,
+ ENode.Atom(ConstantType(Constant(true)))
+ )
+ solver.implies(trueQualifier, qualifier2)
+
+ /** Try to adapt the tree to the given type `pt`
+ *
+ * Returns [[EmptyTree]] if `pt` does not contain qualifiers or if the tree
+ * cannot be adapted, or the adapted tree otherwise.
+ *
+ * Used by [[dotty.tools.dotc.core.Typer]].
+ */
+ def adapt(tree: Tree, pt: Type)(using Context): Tree =
+ if containsQualifier(pt) then
+ trace(i"adapt $tree to qualified type $pt", Printers.qualifiedTypes):
+ if tree.tpe.hasAnnotation(defn.RuntimeCheckedAnnot) then
+ if checkContainsSkolem(pt, tree.srcPos) then
+ tpd.evalOnce(tree): e =>
+ If(
+ e.isInstance(pt),
+ e.asInstance(pt),
+ Throw(New(defn.IllegalArgumentExceptionType, List()))
+ )
+ else
+ tree.withType(ErrorType(em""))
+ else
+ ENode.selfify(tree) match
+ case Some(qualifier) =>
+ val selfifiedTp = QualifiedType(tree.tpe, qualifier)
+ if selfifiedTp <:< pt then tree.cast(selfifiedTp) else EmptyTree
+ case None =>
+ EmptyTree
+ else
+ EmptyTree
+
+ def containsQualifier(tp: Type)(using Context): Boolean =
+ tp match
+ case QualifiedType(_, _) => true
+ case tp: TypeProxy => containsQualifier(tp.underlying)
+ case AndType(tp1, tp2) => containsQualifier(tp1) || containsQualifier(tp2)
+ case OrType(tp1, tp2) => containsQualifier(tp1) || containsQualifier(tp2)
+ case _ => false
+
+ def checkContainsSkolem(tp: Type, pos: SrcPos)(using Context): Boolean =
+ var res = true
+ tp.foreachPart:
+ case QualifiedType(_, qualifier) =>
+ qualifier.foreachType: rootTp =>
+ rootTp.foreachPart:
+ case tp: SkolemType =>
+ report.error(em"The qualified type $qualifier cannot be checked at runtime", pos)
+ res = false
+ case _ => ()
+ case _ => ()
+ res
diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala
new file mode 100644
index 000000000000..10792bb5cf3a
--- /dev/null
+++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala
@@ -0,0 +1,57 @@
+package dotty.tools.dotc.qualified_types
+
+import ENode.{Lambda, OpApply, Op}
+
+import dotty.tools.dotc.config.Printers
+import dotty.tools.dotc.core.Contexts.{ctx, Context}
+import dotty.tools.dotc.core.Symbols.defn
+import dotty.tools.dotc.core.Types.{Type, TypeVar, TypeMap}
+import dotty.tools.dotc.core.Decorators.i
+import dotty.tools.dotc.printing.Showable
+
+class QualifierSolver(using Context):
+
+ def implies(node1: ENode.Lambda, node2: ENode.Lambda) =
+ require(node1.paramTps.length == 1)
+ require(node2.paramTps.length == 1)
+ val node1Inst = node1.normalizeTypes().asInstanceOf[ENode.Lambda]
+ val node2Inst = node2.normalizeTypes().asInstanceOf[ENode.Lambda]
+ val paramTp1 = node1Inst.paramTps.head
+ val paramTp2 = node2Inst.paramTps.head
+ if paramTp1 frozen_<:< paramTp2 then
+ impliesRec(subsParamRefTps(node1Inst.body, node2Inst), node2Inst.body)
+ else if paramTp2 frozen_<:< paramTp1 then
+ impliesRec(node1Inst.body, subsParamRefTps(node2Inst.body, node1Inst))
+ else
+ false
+
+ private def subsParamRefTps(node1Body: ENode, node2: ENode.Lambda): ENode =
+ val paramRefs = node2.paramTps.zipWithIndex.map((tp, i) => ENodeParamRef(i, tp))
+ node1Body.substEParamRefs(0, paramRefs)
+
+ private def impliesRec(node1: ENode, node2: ENode): Boolean =
+ node1 match
+ case OpApply(Op.Or, List(lhs, rhs)) =>
+ return impliesRec(lhs, node2) && impliesRec(rhs, node2)
+ case _ => ()
+
+ val assumptions = ENode.assumptions(node1) ++ ENode.assumptions(node2)
+ val node1WithAssumptions = assumptions.foldLeft(node1)((acc, a) => OpApply(Op.And, List(acc, a.normalizeTypes())))
+ impliesLeaf(EGraph(ctx), node1WithAssumptions, node2)
+
+ protected def impliesLeaf(egraph: EGraph, enode1: ENode, enode2: ENode): Boolean =
+ val node1Canonical = egraph.canonicalize(enode1)
+ val node2Canonical = egraph.canonicalize(enode2)
+ egraph.assertInvariants()
+ egraph.merge(node1Canonical, egraph.trueNode)
+ egraph.repair()
+ egraph.equiv(node2Canonical, egraph.trueNode)
+
+final class ExplainingQualifierSolver(
+ traceIndented: [T] => (String) => (=> T) => T)(using Context) extends QualifierSolver:
+
+ override protected def impliesLeaf(egraph: EGraph, enode1: ENode, enode2: ENode): Boolean =
+ traceIndented(s"${enode1.showNoBreak} --> ${enode2.showNoBreak}"):
+ val res = super.impliesLeaf(egraph, enode1, enode2)
+ if !res then println(egraph.debugString())
+ res
diff --git a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala
index 16219055b8c0..30095b99b1c5 100644
--- a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala
+++ b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala
@@ -175,4 +175,14 @@ object BetaReduce:
Some(expansion1)
else None
end reduceApplication
-end BetaReduce
\ No newline at end of file
+
+ def reduceApplication(ddef: DefDef, argss: List[List[Tree]])(using Context): Option[Tree] =
+ val bindings = new ListBuffer[DefTree]()
+ reduceApplication(ddef, argss, bindings) match
+ case Some(expansion1) =>
+ val bindings1 = bindings.result()
+ Some(seq(bindings1, expansion1))
+ case None =>
+ None
+
+end BetaReduce
diff --git a/compiler/src/dotty/tools/dotc/transform/Erasure.scala b/compiler/src/dotty/tools/dotc/transform/Erasure.scala
index 9a8f5596471f..884b323654e0 100644
--- a/compiler/src/dotty/tools/dotc/transform/Erasure.scala
+++ b/compiler/src/dotty/tools/dotc/transform/Erasure.scala
@@ -602,7 +602,7 @@ object Erasure {
}
override def promote(tree: untpd.Tree)(using Context): tree.ThisTree[Type] = {
- assert(tree.hasType)
+ assert(tree.hasType, i"promote called on tree without type: ${tree.show}")
val erasedTp = erasedType(tree)
report.log(s"promoting ${tree.show}: ${erasedTp.showWithUnderlying()}")
tree.withType(erasedTp)
diff --git a/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala b/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala
index d2a72e10fcfc..aa470a0e9ea8 100644
--- a/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala
+++ b/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala
@@ -59,12 +59,7 @@ class InlinePatterns extends MiniPhase:
case Block(TypeDef(_, template: Template) :: Nil, Apply(Select(New(_),_), Nil)) if template.constr.rhs.isEmpty =>
template.body match
case List(ddef @ DefDef(`name`, _, _, _)) =>
- val bindings = new ListBuffer[DefTree]()
- BetaReduce.reduceApplication(ddef, argss, bindings) match
- case Some(expansion1) =>
- val bindings1 = bindings.result()
- seq(bindings1, expansion1)
- case None => tree
+ BetaReduce.reduceApplication(ddef, argss).getOrElse(tree)
case _ => tree
case _ => tree
diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala
index 128655debc2e..a5f88db28871 100644
--- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala
+++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala
@@ -26,6 +26,7 @@ import cc.*
import dotty.tools.dotc.transform.MacroAnnotations.hasMacroAnnotation
import dotty.tools.dotc.core.NameKinds.DefaultGetterName
import ast.TreeInfo
+import dotty.tools.dotc.qualified_types.QualifiedAnnotation
object PostTyper {
val name: String = "posttyper"
@@ -208,11 +209,15 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
}
private def transformAnnot(annot: Annotation)(using Context): Annotation =
- val tree1 =
- annot match
- case _: BodyAnnotation => annot.tree
- case _ => copySymbols(annot.tree)
- annot.derivedAnnotation(transformAnnotTree(tree1))
+ annot match
+ case _: QualifiedAnnotation =>
+ annot
+ case _ =>
+ val tree1 =
+ annot match
+ case _: BodyAnnotation => annot.tree
+ case _ => copySymbols(annot.tree)
+ annot.derivedAnnotation(transformAnnotTree(tree1))
/** Transforms all annotations in the given type. */
private def transformAnnotsIn(using Context) =
diff --git a/compiler/src/dotty/tools/dotc/transform/TreeExtractors.scala b/compiler/src/dotty/tools/dotc/transform/TreeExtractors.scala
index 8d5b7c28bbbc..9ec91c79abad 100644
--- a/compiler/src/dotty/tools/dotc/transform/TreeExtractors.scala
+++ b/compiler/src/dotty/tools/dotc/transform/TreeExtractors.scala
@@ -19,6 +19,15 @@ object TreeExtractors {
}
}
+ /** Match arg.op() and extract (arg, op.symbol) */
+ object UnaryOp:
+ def unapply(t: Tree)(using Context): Option[(Symbol, Tree)] =
+ t match
+ case Apply(sel @ Select(arg, _), Nil) =>
+ Some((sel.symbol, arg))
+ case _ =>
+ None
+
/** Match new C(args) and extract (C, args).
* Also admit new C(args): T and {new C(args)}.
*/
diff --git a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala
index a8c8ec8ce1d8..26364298e41f 100644
--- a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala
+++ b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala
@@ -18,6 +18,8 @@ import config.Printers.{ transforms => debug }
import patmat.Typ
import dotty.tools.dotc.util.SrcPos
+import qualified_types.QualifiedType
+
/** This transform normalizes type tests and type casts,
* also replacing type tests with singleton argument type with reference equality check
* Any remaining type tests
@@ -323,7 +325,7 @@ object TypeTestsCasts {
* The transform happens before erasure of `testType`, thus cannot be merged
* with `transformIsInstanceOf`, which depends on erased type of `testType`.
*/
- def transformTypeTest(expr: Tree, testType: Type, flagUnrelated: Boolean): Tree = testType.dealias match {
+ def transformTypeTest(expr: Tree, testType: Type, flagUnrelated: Boolean): Tree = testType.dealiasKeepRefiningAnnots match {
case tref: TermRef if tref.symbol == defn.EmptyTupleModule =>
ref(defn.RuntimeTuples_isInstanceOfEmptyTuple).appliedTo(expr)
case _: SingletonType =>
@@ -352,6 +354,16 @@ object TypeTestsCasts {
ref(defn.RuntimeTuples_isInstanceOfNonEmptyTuple).appliedTo(expr)
case AppliedType(tref: TypeRef, _) if tref.symbol == defn.PairClass =>
ref(defn.RuntimeTuples_isInstanceOfNonEmptyTuple).appliedTo(expr)
+ case QualifiedType(parent, qualifier) =>
+ qualifier.toTree() match
+ case closureDef(qualifierDef) =>
+ evalOnce(expr): e =>
+ // e.isInstanceOf[baseType] && qualifier(e.asInstanceOf[baseType])
+ val arg = e.asInstance(parent)
+ val qualifierTest = BetaReduce.reduceApplication(qualifierDef, List(List(arg))).get
+ transformTypeTest(e, parent, flagUnrelated).and(qualifierTest)
+ case tree =>
+ throw new IllegalStateException("Malformed qualifier tree: $tree, expected a closure definition")
case _ =>
val testWidened = testType.widen
defn.untestableClasses.find(testWidened.isRef(_)) match
diff --git a/compiler/src/dotty/tools/dotc/typer/ConstFold.scala b/compiler/src/dotty/tools/dotc/typer/ConstFold.scala
index bd726afe5bba..76c36bfe5c2c 100644
--- a/compiler/src/dotty/tools/dotc/typer/ConstFold.scala
+++ b/compiler/src/dotty/tools/dotc/typer/ConstFold.scala
@@ -61,17 +61,6 @@ object ConstFold:
tree.withFoldedType(Constant(targ.tpe))
case _ => tree
- private object ConstantTree:
- def unapply(tree: Tree)(using Context): Option[Constant] =
- tree match
- case Inlined(_, Nil, expr) => unapply(expr)
- case Typed(expr, _) => unapply(expr)
- case Literal(c) if c.tag == Constants.NullTag => Some(c)
- case _ =>
- tree.tpe.widenTermRefExpr.normalized.simplified match
- case ConstantType(c) => Some(c)
- case _ => None
-
extension [T <: Tree](tree: T)(using Context)
private def withFoldedType(c: Constant | Null): T =
if c == null then tree else tree.withType(ConstantType(c)).asInstanceOf[T]
diff --git a/compiler/src/dotty/tools/dotc/typer/ImportInfo.scala b/compiler/src/dotty/tools/dotc/typer/ImportInfo.scala
index e8698baa46ac..1915155f5aca 100644
--- a/compiler/src/dotty/tools/dotc/typer/ImportInfo.scala
+++ b/compiler/src/dotty/tools/dotc/typer/ImportInfo.scala
@@ -206,7 +206,7 @@ class ImportInfo(symf: Context ?=> Symbol,
/** Does this import clause or a preceding import clause enable `feature`?
*
- * @param feature a possibly quailified name, e.g.
+ * @param feature a possibly qualified name, e.g.
* strictEquality
* experimental.genericNumberLiterals
*
diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
index f1ad0f8520f1..3d0ca2340ecf 100644
--- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
+++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
@@ -14,6 +14,7 @@ import Checking.{checkNoPrivateLeaks, checkNoWildcard}
import cc.CaptureSet
import util.Property
import transform.Splicer
+import qualified_types.QualifiedType
trait TypeAssigner {
import tpd.*
@@ -572,7 +573,10 @@ trait TypeAssigner {
def assignType(tree: untpd.Annotated, arg: Tree, annot: Tree)(using Context): Annotated = {
assert(tree.isType) // annotating a term is done via a Typed node, can't use Annotate directly
- tree.withType(AnnotatedType(arg.tpe, Annotation(annot)))
+ if Annotations.annotClass(annot) == defn.QualifiedAnnot then
+ tree.withType(QualifiedType(arg.tpe, annot))
+ else
+ tree.withType(AnnotatedType(arg.tpe, Annotation(annot)))
}
def assignType(tree: untpd.PackageDef, pid: Tree)(using Context): PackageDef =
diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala
index 3462e8455394..121c9bd108bb 100644
--- a/compiler/src/dotty/tools/dotc/typer/Typer.scala
+++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala
@@ -47,6 +47,7 @@ import reporting.*
import Nullables.*
import NullOpsDecorator.*
import cc.{CheckCaptures, isRetainsLike}
+import qualified_types.{QualifiedTypes, QualifiedType}
import config.Config
import config.MigrationVersion
import transform.CheckUnused.OriginalName
@@ -2304,9 +2305,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}
/** Type a case. */
- def typedCase(tree: untpd.CaseDef, sel: Tree, wideSelType: Type, pt: Type)(using Context): CaseDef = {
+ def typedCase(tree0: untpd.CaseDef, sel: Tree, wideSelType: Type, pt: Type)(using Context): CaseDef = {
val originalCtx = ctx
val gadtCtx: Context = ctx.fresh.setFreshGADTBounds
+ val tree = desugar.caseDef(tree0)
def caseRest(pat: Tree)(using Context) = {
val pt1 = instantiateMatchTypeProto(pat, pt) match {
@@ -2540,7 +2542,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// untyped tree is no longer accessed after all
// accesses with typedTypeTree are done.
case None =>
- errorTree(tree, em"Something's wrong: missing original symbol for type tree")
+ errorTree(tree, em"Something's wrong: missing original symbol for type tree ${tree}")
}
case _ =>
completeTypeTree(InferredTypeTree(), pt, tree)
@@ -4816,7 +4818,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
return readapt(tree.cast(captured))
// drop type if prototype is Unit
- if pt.isRef(defn.UnitClass) then
+ if pt.isRef(defn.UnitClass, false) then
// local adaptation makes sure every adapted tree conforms to its pt
// so will take the code path that decides on inlining
val tree1 = adapt(tree, WildcardType, locked)
@@ -4868,6 +4870,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case _ =>
case _ =>
+ // Try to adapt to a qualified type
+ val adapted = QualifiedTypes.adapt(tree, pt)
+ if !adapted.isEmpty then
+ return readapt(adapted)
+
def recover(failure: SearchFailureType) =
if canDefineFurther(wtp) || canDefineFurther(pt) then readapt(tree)
else
diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala
index baf1b4d66306..3e9e8a1b6c14 100644
--- a/compiler/test/dotty/tools/dotc/CompilationTests.scala
+++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala
@@ -38,6 +38,7 @@ class CompilationTests {
compileFile("tests/pos-special/sourcepath/outer/nested/Test4.scala", defaultOptions.and("-sourcepath", "tests/pos-special/sourcepath")),
compileFilesInDir("tests/pos-scala2", defaultOptions.and("-source", "3.0-migration")),
compileFilesInDir("tests/pos-custom-args/captures", defaultOptions.and("-language:experimental.captureChecking", "-language:experimental.separationChecking", "-source", "3.8")),
+ compileFilesInDir("tests/pos-custom-args/qualified-types", defaultOptions.and("-language:experimental.qualifiedTypes")),
compileFile("tests/pos-special/utf8encoded.scala", defaultOptions.and("-encoding", "UTF8")),
compileFile("tests/pos-special/utf16encoded.scala", defaultOptions.and("-encoding", "UTF16")),
compileDir("tests/pos-special/i18589", defaultOptions.and("-Wsafe-init").without("-Ycheck:all")),
@@ -150,6 +151,7 @@ class CompilationTests {
compileFilesInDir("tests/neg", defaultOptions, FileFilter.exclude(TestSources.negScala2LibraryTastyExcludelisted)),
compileFilesInDir("tests/neg-deep-subtype", allowDeepSubtypes),
compileFilesInDir("tests/neg-custom-args/captures", defaultOptions.and("-language:experimental.captureChecking", "-language:experimental.separationChecking", "-source", "3.8")),
+ compileFilesInDir("tests/neg-custom-args/qualified-types", defaultOptions.and("-language:experimental.qualifiedTypes")),
compileFile("tests/neg-custom-args/sourcepath/outer/nested/Test1.scala", defaultOptions.and("-sourcepath", "tests/neg-custom-args/sourcepath")),
compileDir("tests/neg-custom-args/sourcepath2/hi", defaultOptions.and("-sourcepath", "tests/neg-custom-args/sourcepath2", "-Werror")),
compileList("duplicate source", List(
@@ -173,6 +175,7 @@ class CompilationTests {
compileFilesInDir("tests/run", defaultOptions.and("-Wsafe-init")),
compileFilesInDir("tests/run-deep-subtype", allowDeepSubtypes),
compileFilesInDir("tests/run-custom-args/captures", allowDeepSubtypes.and("-language:experimental.captureChecking", "-language:experimental.separationChecking", "-source", "3.8")),
+ compileFilesInDir("tests/run-custom-args/qualified-types", defaultOptions.and("-language:experimental.qualifiedTypes")),
// Run tests for legacy lazy vals.
compileFilesInDir("tests/run", defaultOptions.and("-Wsafe-init", "-Ylegacy-lazy-vals", "-Ycheck-constraint-deps"), FileFilter.include(TestSources.runLazyValsAllowlist)),
).checkRuns()
diff --git a/compiler/test/dotty/tools/dotc/qualified_types/EGraphTest.scala b/compiler/test/dotty/tools/dotc/qualified_types/EGraphTest.scala
new file mode 100644
index 000000000000..0162da177cd1
--- /dev/null
+++ b/compiler/test/dotty/tools/dotc/qualified_types/EGraphTest.scala
@@ -0,0 +1,235 @@
+package dotty.tools.dotc.qualified_types
+
+import dotty.tools.DottyTest
+import dotty.tools.dotc.core.Decorators.i
+import dotty.tools.dotc.ast.tpd
+
+import org.junit.Assert.assertEquals
+import org.junit.Test
+
+class EGraphTest extends QualifiedTypesTest:
+
+ def checkImplies(fromString: String, toString: String, egraphString: String, expected: Boolean = true): Unit =
+ val src = s"""
+ |def test = {
+ | val b1: Boolean = ???
+ | val b2: Boolean = ???
+ | val b3: Boolean = ???
+ | val w: Int = ???
+ | val x: Int = ???
+ | val y: Int = ???
+ | val z: Int = ???
+ | def f(a: Int): Boolean = ???
+ | def g(a: Int): Int = ???
+ | def h(a: Int, b: Int): Int = ???
+ | def id[T](a: T): T = a
+ | type Vec[T]
+ | type Pos = {v: Int with v > 0}
+ | def len[T](v: Vec[T]): Pos = ???
+ | val v1: Vec[Int] = ???
+ | val v2: Vec[Int] = ???
+ | val v3: Vec[Int] = ???
+ | val from: Boolean = $fromString
+ | val to: Boolean = $toString
+ |}""".stripMargin
+ checkCompileExpr(src): stats =>
+ val testTree = getDefDef(stats, "test")
+ val body = testTree.rhs.asInstanceOf[tpd.Block]
+ val fromTree = getValDef(body.stats, "from").rhs
+ val toTree = getValDef(body.stats, "to").rhs
+ val egraph = EGraph(ctx, checksEnabled = true)
+ val from = ENode.fromTree(fromTree).get.normalizeTypes()
+ val to = ENode.fromTree(toTree).get.normalizeTypes()
+ val fromCanonical = egraph.canonicalize(from)
+ val toCanonical = egraph.canonicalize(to)
+ egraph.merge(fromCanonical, egraph.trueNode)
+ egraph.repair()
+ assertStringEquals(egraphString, egraph.debugString())
+ val res = egraph.equiv(toCanonical, egraph.trueNode)
+ assertEquals(s"Expected $fromString --> $toString to be $expected", expected, res)
+
+ def checkNotImplies(fromString: String, toString: String, egraphString: String): Unit =
+ checkImplies(fromString, toString, egraphString, expected = false)
+
+ @Test def test1() =
+ checkImplies(
+ "true",
+ "true",
+ """-1: {}
+ |0: {}
+ |1: {}
+ |false: {}
+ |true: {}
+ |""".stripMargin
+ )
+
+ @Test def test2() =
+ checkImplies(
+ "b1",
+ "b1",
+ """-1: {}
+ |0: {}
+ |1: {}
+ |false: {}
+ |true: {b1}
+ |""".stripMargin
+ )
+
+ @Test def test3() =
+ checkNotImplies(
+ "b1",
+ "b2",
+ """-1: {}
+ |0: {}
+ |1: {}
+ |b2: {}
+ |false: {}
+ |true: {b1}
+ |""".stripMargin
+ )
+
+ @Test def test4() =
+ checkImplies(
+ "b1 && b2",
+ "b2",
+ """-1: {}
+ |0: {}
+ |1: {}
+ |false: {}
+ |true: {b1, b1 && b2, b2}
+ |""".stripMargin
+ )
+
+ @Test def test5() =
+ checkNotImplies(
+ "b1 || b2",
+ "b2",
+ """-1: {}
+ |0: {}
+ |1: {}
+ |b1: {}
+ |b2: {}
+ |false: {}
+ |true: {b1 || b2}
+ |""".stripMargin
+ )
+
+ @Test def test6() =
+ checkImplies(
+ "b1 && b2 && b3",
+ "b3",
+ """-1: {}
+ |0: {}
+ |1: {}
+ |false: {}
+ |true: {b1, b1 && b2, b1 && b2 && b3, b2, b3}
+ |""".stripMargin
+ )
+
+ @Test def test7() =
+ checkImplies(
+ "b1 && b1 == b2",
+ "b2",
+ """-1: {}
+ |0: {}
+ |1: {}
+ |false: {}
+ |true: {b1, b1 && b1 == b2, b1 == b2, b2}
+ |""".stripMargin
+ )
+
+ @Test def test8() =
+ checkImplies(
+ "b1 && b1 == b2 && b2 == b3",
+ "b3",
+ """-1: {}
+ |0: {}
+ |1: {}
+ |false: {}
+ |true: {b1, b1 && b1 == b2, b1 && b1 == b2 && b2 == b3, b1 == b2, b2, b2 == b3, b3}
+ |""".stripMargin
+ )
+
+ @Test def test9() =
+ checkImplies(
+ "f(x) && x == y",
+ "f(y)",
+ """-1: {}
+ |0: {}
+ |1: {}
+ |f: {}
+ |false: {}
+ |true: {f(x), f(x) && x == y, f(y), x == y}
+ |x: {y}
+ |""".stripMargin
+ )
+
+ @Test def nestedFunctions() =
+ checkImplies(
+ "f(g(x)) && g(x) == g(y)",
+ "f(g(y))",
+ """-1: {}
+ |0: {}
+ |1: {}
+ |f: {}
+ |false: {}
+ |g: {}
+ |g(x): {g(y)}
+ |true: {f(g(x)), f(g(x)) && g(x) == g(y), f(g(y)), g(x) == g(y)}
+ |x: {}
+ |y: {}
+ |""".stripMargin
+ )
+
+ @Test def multipleArgs() =
+ checkImplies(
+ "y == z",
+ "h(x, y) == h(x, z)",
+ """-1: {}
+ |0: {}
+ |1: {}
+ |false: {}
+ |h: {}
+ |h(x, y): {h(x, z)}
+ |true: {h(x, y) == h(x, z), y == z}
+ |x: {}
+ |y: {z}
+ |""".stripMargin
+ )
+
+ @Test def multipleArgsDeep() =
+ checkImplies(
+ "f(h(x, y)) && y == z",
+ "f(h(x, z))",
+ """-1: {}
+ |0: {}
+ |1: {}
+ |f: {}
+ |false: {}
+ |h: {}
+ |h(x, y): {h(x, z)}
+ |true: {f(h(x, y)), f(h(x, y)) && y == z, f(h(x, z)), y == z}
+ |x: {}
+ |y: {z}
+ |""".stripMargin
+ )
+
+ @Test def sizeSum() =
+ checkImplies(
+ "len(v1) == len(v2) + len(v3) && len(v2) == 3 && len(v3) == 4",
+ "len(v1) == 7",
+ """-1: {}
+ |0: {}
+ |1: {}
+ |3: {len[Int](v2)}
+ |4: {len[Int](v3)}
+ |7: {len[Int](v1), len[Int](v2) + len[Int](v3)}
+ |false: {}
+ |len: {}
+ |len[Int]: {}
+ |true: {len[Int](v1) == 7, len[Int](v1) == len[Int](v2) + len[Int](v3), len[Int](v1) == len[Int](v2) + len[Int](v3) && len[Int](v2) == 3, len[Int](v1) == len[Int](v2) + len[Int](v3) && len[Int](v2) == 3 && len[Int](v3) == 4, len[Int](v2) + len[Int](v3) == 7, len[Int](v2) == 3, len[Int](v3) == 4}
+ |v1: {}
+ |v2: {}
+ |v3: {}
+ |""".stripMargin
+ )
diff --git a/compiler/test/dotty/tools/dotc/qualified_types/ENodeTest.scala b/compiler/test/dotty/tools/dotc/qualified_types/ENodeTest.scala
new file mode 100644
index 000000000000..24051d09a718
--- /dev/null
+++ b/compiler/test/dotty/tools/dotc/qualified_types/ENodeTest.scala
@@ -0,0 +1,60 @@
+package dotty.tools.dotc.qualified_types
+
+import dotty.tools.DottyTest
+import dotty.tools.dotc.ast.tpd
+import dotty.tools.dotc.core.Contexts.Context
+
+import org.junit.Assert.assertEquals
+import org.junit.Test
+
+class ENodeTest extends QualifiedTypesTest:
+
+ def checkFromToTree(exprString: String, resultString: String): Unit =
+ checkCompileExpr(s"val v = $exprString"): stats =>
+ val tree1: tpd.Tree = getValDef(stats, "v").rhs
+ val enode: ENode = ENode.fromTree(tree1).get
+ assertStringEquals(resultString, enode.show)
+ val tree2: tpd.Tree = enode.toTree()
+ assertStringEquals(tree1.show, tree2.show)
+
+ @Test def testFromToTree1() =
+ checkFromToTree(
+ "(param0: Int) => param0",
+ "(_: Int) => eparam0"
+ )
+
+ @Test def testFromToTree2() =
+ checkFromToTree(
+ "(param0: Int) => param0 + 1",
+ "(_: Int) => eparam0 + 1"
+ )
+
+ @Test def testFromToTree3() =
+ // ENode.fromTree and ENode#toTree do not perform constant folding or
+ // normalization. This is only done when adding E-Nodes to an E-Graph.
+ checkFromToTree(
+ "(param0: Int) => param0 + 1 + 1",
+ "(_: Int) => eparam0 + 1 + 1"
+ )
+
+ @Test def testFromToTree4() =
+ checkFromToTree(
+ "(param0: Int) => (param1: Int) => param0 + param1",
+ // In De Bruijn notation the outermost parameter is param1 and the
+ // innermost param0
+ "(_: Int) => (_: Int) => eparam1 + eparam0"
+ )
+
+ @Test def testFromToTree5() =
+ checkFromToTree(
+ "(param0: Int, param1: Int) => param0 + param1",
+ // Same for paramter lists with multiple parameters: the outermost
+ // parameter is param1 and the innermost param0
+ "(_: Int, _: Int) => eparam1 + eparam0"
+ )
+
+ @Test def testFromToTree6() =
+ checkFromToTree(
+ "(param0: Int, param1: Int) => (param2: Int, param3: Int) => param0 + param1 + param2 + param3",
+ "(_: Int, _: Int) => (_: Int, _: Int) => eparam3 + eparam2 + eparam1 + eparam0"
+ )
diff --git a/compiler/test/dotty/tools/dotc/qualified_types/QualifiedTypesTest.scala b/compiler/test/dotty/tools/dotc/qualified_types/QualifiedTypesTest.scala
new file mode 100644
index 000000000000..5d0fbeb6849d
--- /dev/null
+++ b/compiler/test/dotty/tools/dotc/qualified_types/QualifiedTypesTest.scala
@@ -0,0 +1,41 @@
+package dotty.tools.dotc.qualified_types
+
+import dotty.tools.DottyTest
+import dotty.tools.dotc.ast.tpd
+import dotty.tools.dotc.core.Contexts.{Context, FreshContext}
+import dotty.tools.dotc.core.Decorators.i
+
+import org.junit.Assert.assertEquals
+import org.junit.runners.MethodSorters
+import org.junit.FixMethodOrder
+
+@FixMethodOrder(MethodSorters.JVM)
+abstract class QualifiedTypesTest extends DottyTest:
+
+ override protected def initializeCtx(fc: FreshContext): Unit =
+ super.initializeCtx(fc)
+ fc.setSetting(fc.settings.XnoEnrichErrorMessages, true)
+ fc.setSetting(fc.settings.color, "never")
+ fc.setSetting(fc.settings.language, List("experimental.qualifiedTypes").asInstanceOf)
+
+ def checkCompileExpr(statsString: String)(assertion: List[tpd.Tree] => Context ?=> Unit): Unit =
+ checkCompile("typer", s"object Test { $statsString }"): (pkg, context) =>
+ given Context = context
+ val packageStats = pkg.asInstanceOf[tpd.PackageDef].stats
+ val clazz = getTypeDef(packageStats, "Test$")
+ val clazzStats = clazz.rhs.asInstanceOf[tpd.Template].body
+ assertion(clazzStats)(using context)
+
+ def getTypeDef(trees: List[tpd.Tree], name: String)(using Context): tpd.TypeDef =
+ trees.collectFirst { case td: tpd.TypeDef if td.name.toString() == name => td }.get
+
+ def getValDef(trees: List[tpd.Tree], name: String)(using Context): tpd.ValDef =
+ trees.collectFirst { case vd: tpd.ValDef if vd.name.toString() == name => vd }.get
+
+ def getDefDef(trees: List[tpd.Tree], name: String)(using Context): tpd.DefDef =
+ trees.collectFirst { case vd: tpd.DefDef if vd.name.toString() == name => vd }.get
+
+ def assertStringEquals(expected: String, found: String)(using Context): Unit =
+ val formattedExpected = if expected.contains('\n') then "\n" + expected.linesIterator.map(" " + _).mkString("\n") else expected
+ val formattedFound = if found.contains('\n') then "\n" + found.linesIterator.map(" " + _).mkString("\n") else found
+ assertEquals(s"\n Expected: $formattedExpected\n Found: $formattedFound\n", expected, found)
diff --git a/library/src/scala/annotation/qualified.scala b/library/src/scala/annotation/qualified.scala
new file mode 100644
index 000000000000..0c0b6532dd43
--- /dev/null
+++ b/library/src/scala/annotation/qualified.scala
@@ -0,0 +1,4 @@
+package scala.annotation
+
+/** Annotation for qualified types. */
+@experimental class qualified[T](predicate: T => Boolean) extends RefiningAnnotation
diff --git a/library/src/scala/language.scala b/library/src/scala/language.scala
index 63941a86bd67..b6ca43b8e067 100644
--- a/library/src/scala/language.scala
+++ b/library/src/scala/language.scala
@@ -286,6 +286,10 @@ object language {
@compileTimeOnly("`separationChecking` can only be used at compile time in import statements")
object separationChecking
+ /** Experimental support for qualified types */
+ @compileTimeOnly("`qualifiedTypes` is only be used at compile time")
+ object qualifiedTypes
+
/** Experimental support for automatic conversions of arguments, without requiring
* a language import `import scala.language.implicitConversions`.
*
diff --git a/library/src/scala/runtime/stdLibPatches/language.scala b/library/src/scala/runtime/stdLibPatches/language.scala
index c4da436a78d8..b1e38756493d 100644
--- a/library/src/scala/runtime/stdLibPatches/language.scala
+++ b/library/src/scala/runtime/stdLibPatches/language.scala
@@ -93,6 +93,10 @@ object language:
@compileTimeOnly("`separationChecking` can only be used at compile time in import statements")
object separationChecking
+ /** Experimental support for qualified types */
+ @compileTimeOnly("`qualifiedTypes` is only be used at compile time")
+ object qualifiedTypes
+
/** Experimental support for automatic conversions of arguments, without requiring
* a language import `import scala.language.implicitConversions`.
*
diff --git a/project/Build.scala b/project/Build.scala
index 8c233af8b2b3..e573dbb320ae 100644
--- a/project/Build.scala
+++ b/project/Build.scala
@@ -1085,6 +1085,7 @@ object Build {
file(s"${baseDirectory.value}/src/scala/annotation/MacroAnnotation.scala"),
file(s"${baseDirectory.value}/src/scala/annotation/alpha.scala"),
file(s"${baseDirectory.value}/src/scala/annotation/publicInBinary.scala"),
+ file(s"${baseDirectory.value}/src/scala/annotation/qualified.scala"),
file(s"${baseDirectory.value}/src/scala/annotation/init.scala"),
file(s"${baseDirectory.value}/src/scala/annotation/unroll.scala"),
file(s"${baseDirectory.value}/src/scala/annotation/targetName.scala"),
@@ -1224,6 +1225,7 @@ object Build {
file(s"${baseDirectory.value}/src/scala/annotation/MacroAnnotation.scala"),
file(s"${baseDirectory.value}/src/scala/annotation/alpha.scala"),
file(s"${baseDirectory.value}/src/scala/annotation/publicInBinary.scala"),
+ file(s"${baseDirectory.value}/src/scala/annotation/qualified.scala"),
file(s"${baseDirectory.value}/src/scala/annotation/init.scala"),
file(s"${baseDirectory.value}/src/scala/annotation/unroll.scala"),
file(s"${baseDirectory.value}/src/scala/annotation/targetName.scala"),
diff --git a/project/MiMaFilters.scala b/project/MiMaFilters.scala
index e2ca983081d2..809a9acfc593 100644
--- a/project/MiMaFilters.scala
+++ b/project/MiMaFilters.scala
@@ -32,6 +32,9 @@ object MiMaFilters {
ProblemFilters.exclude[MissingClassProblem]("scala.caps.Classifier"),
ProblemFilters.exclude[MissingClassProblem]("scala.caps.SharedCapability"),
ProblemFilters.exclude[MissingClassProblem]("scala.caps.Control"),
+
+ ProblemFilters.exclude[MissingFieldProblem]("scala.runtime.stdLibPatches.language#experimental.qualifiedTypes"),
+ ProblemFilters.exclude[MissingClassProblem]("scala.runtime.stdLibPatches.language$experimental$qualifiedTypes$"),
),
// Additions since last LTS
diff --git a/tests/neg-custom-args/qualified-types/adapt_neg.scala b/tests/neg-custom-args/qualified-types/adapt_neg.scala
new file mode 100644
index 000000000000..8722ffcbb7b8
--- /dev/null
+++ b/tests/neg-custom-args/qualified-types/adapt_neg.scala
@@ -0,0 +1,22 @@
+def f(x: Int): Int = ???
+case class IntBox(x: Int)
+case class Box[T](x: T)
+
+def test: Unit =
+ val x: Int = ???
+ val y: Int = ???
+ def g(x: Int): Int = ???
+
+ val v1: {v: Int with v == 1} = 2 // error
+ val v2: {v: Int with v == x} = y // error
+ val v3: {v: Int with v == x + 1} = x + 2 // error
+ val v4: {v: Int with v == f(x)} = g(x) // error
+ val v5: {v: Int with v == g(x)} = f(x) // error
+ val v6: {v: IntBox with v == IntBox(x)} = IntBox(x + 1) // error
+ val v7: {v: Box[Int] with v == Box(x)} = Box(x + 1) // error
+ val v8: {v: Int with v == x + f(x)} = x + g(x) // error
+ val v9: {v: Int with v == x + g(x)} = x + f(x) // error
+ val v10: {v: Int with v == f(x + 1)} = f(x + 2) // error
+ val v11: {v: Int with v == g(x + 1)} = g(x + 2) // error
+ val v12: {v: IntBox with v == IntBox(x + 1)} = IntBox(x) // error
+ val v13: {v: Box[Int] with v == Box(x + 1)} = Box(x) // error
diff --git a/tests/neg-custom-args/qualified-types/list_apply_neg.scala b/tests/neg-custom-args/qualified-types/list_apply_neg.scala
new file mode 100644
index 000000000000..1c3eedda0291
--- /dev/null
+++ b/tests/neg-custom-args/qualified-types/list_apply_neg.scala
@@ -0,0 +1,4 @@
+type PosInt = {v: Int with v > 0}
+
+@main def Test =
+ val l: List[PosInt] = List(1,-2,3) // error // error
diff --git a/tests/neg-custom-args/qualified-types/runtimeChecked_dependent_neg.scala b/tests/neg-custom-args/qualified-types/runtimeChecked_dependent_neg.scala
new file mode 100644
index 000000000000..f8fcc5c576f0
--- /dev/null
+++ b/tests/neg-custom-args/qualified-types/runtimeChecked_dependent_neg.scala
@@ -0,0 +1,8 @@
+def foo(x: Int, y: {v: Int with v > x}): y.type = y
+
+def getInt(): Int =
+ println("getInt called")
+ 42
+
+@main def Test =
+ val res = foo(getInt(), 2.runtimeChecked) // error
diff --git a/tests/neg-custom-args/qualified-types/subtyping_egraph_state.scala b/tests/neg-custom-args/qualified-types/subtyping_egraph_state.scala
new file mode 100644
index 000000000000..a40f6db5331c
--- /dev/null
+++ b/tests/neg-custom-args/qualified-types/subtyping_egraph_state.scala
@@ -0,0 +1,7 @@
+def test: Unit =
+ val b: Boolean = ???
+ val b2: Boolean = ???
+ summon[{u: Unit with b && b2} <:< {u: Unit with b}]
+ // Checks that E-Graph state is reset after the implication check: b is no
+ // longer true
+ summon[{u: Unit with true} <:< {u: Unit with b}] // error
diff --git a/tests/neg-custom-args/qualified-types/subtyping_lambdas_neg.scala.scala b/tests/neg-custom-args/qualified-types/subtyping_lambdas_neg.scala.scala
new file mode 100644
index 000000000000..c36d13c0941f
--- /dev/null
+++ b/tests/neg-custom-args/qualified-types/subtyping_lambdas_neg.scala.scala
@@ -0,0 +1,9 @@
+def toBool[T](x: T): Boolean = ???
+def tp[T](): Any = ???
+
+def test: Unit =
+ val x: {l: List[Int] with toBool((x: String, y: x.type) => x.length > 0)} = ??? // error: cannot turn method type into closure because it has internal parameter dependencies
+ summon[{l: List[Int] with toBool((x: String, y: String) => tp[x.type]())} =:= {l: List[Int] with toBool((x: String, y: String) => tp[y.type]())}] // error
+
+ summon[{l: List[Int] with toBool((x: Double) => (y: Int) => x == y)} =:= {l: List[Int] with toBool((a: Double) => (b: Int) => a == a)}] // error
+ summon[{l: List[Int] with toBool((x: Int) => (y: Int) => x == y)} =:= {l: List[Int] with toBool((a: Int) => (b: Int) => a == a)}] // error
diff --git a/tests/neg-custom-args/qualified-types/subtyping_objects_neg.scala b/tests/neg-custom-args/qualified-types/subtyping_objects_neg.scala
new file mode 100644
index 000000000000..9a0e444232b2
--- /dev/null
+++ b/tests/neg-custom-args/qualified-types/subtyping_objects_neg.scala
@@ -0,0 +1,52 @@
+class Box[T](val x: T)
+
+class BoxMutable[T](var x: T)
+
+class Foo(val id: String):
+ def this(x: Int) = this(x.toString)
+
+class Person(val name: String, val age: Int)
+
+class PersonCurried(val name: String)(val age: Int)
+
+class PersonMutable(val name: String, var age: Int)
+
+case class PersonCaseMutable(name: String, var age: Int)
+
+case class PersonCaseSecondary(name: String, age: Int):
+ def this(name: String) = this(name, 0)
+
+case class PersonCaseEqualsOverriden(name: String, age: Int):
+ override def equals(that: Object): Boolean = this eq that
+
+def test: Unit =
+ summon[{b: Box[Int] with b == Box(1)} =:= {b: Box[Int] with b == Box(1)}] // error // error
+
+ summon[{b: BoxMutable[Int] with b == BoxMutable(1)} =:= {b: BoxMutable[Int] with b == BoxMutable(1)}] // error // error
+ // TODO(mbovel): restrict selection to stable members
+ //summon[{b: BoxMutable[Int] with b.x == 3} =:= {b: BoxMutable[Int] with b.x == 3}]
+
+ summon[{f: Foo with f == Foo("hello")} =:= {f: Foo with f == Foo("hello")}] // error // error
+ summon[{f: Foo with f == Foo(1)} =:= {f: Foo with f == Foo(1)}] // error // error
+ summon[{s: String with Foo("hello").id == s} =:= {s: String with s == "hello"}] // error
+
+ summon[{p: Person with p == Person("Alice", 30)} =:= {p: Person with p == Person("Alice", 30)}] // error // error
+ summon[{s: String with Person("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error
+ summon[{n: Int with Person("Alice", 30).age == n} =:= {n: Int with n == 30}] // error
+
+ summon[{p: PersonCurried with p == PersonCurried("Alice")(30)} =:= {p: PersonCurried with p == PersonCurried("Alice")(30)}] // error // error
+ summon[{s: String with PersonCurried("Alice")(30).name == s} =:= {s: String with s == "Alice"}] // error
+ summon[{n: Int with PersonCurried("Alice")(30).age == n} =:= {n: Int with n == 30}] // error
+
+ summon[{p: PersonMutable with p == PersonMutable("Alice", 30)} =:= {p: PersonMutable with p == PersonMutable("Alice", 30)}] // error // error
+ summon[{s: String with PersonMutable("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error
+ summon[{n: Int with PersonMutable("Alice", 30).age == n} =:= {n: Int with n == 30}] // error
+
+ summon[{n: Int with PersonCaseMutable("Alice", 30).age == n} =:= {n: Int with n == 30}] // error
+
+ summon[{s: String with new PersonCaseSecondary("Alice").name == s} =:= {s: String with s == "Alice"}] // error
+ summon[{n: Int with new PersonCaseSecondary("Alice").age == n} =:= {n: Int with n == 0}] // error
+
+ summon[{p: PersonCaseEqualsOverriden with PersonCaseEqualsOverriden("Alice", 30) == p} =:= {p: PersonCaseEqualsOverriden with p == PersonCaseEqualsOverriden("Alice", 30)}] // error // error
+ summon[{s: String with PersonCaseEqualsOverriden("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error
+ summon[{n: Int with PersonCaseEqualsOverriden("Alice", 30).age == n} =:= {n: Int with n == 30}] // error
diff --git a/tests/neg-custom-args/qualified-types/subtyping_singletons_neg.scala b/tests/neg-custom-args/qualified-types/subtyping_singletons_neg.scala
new file mode 100644
index 000000000000..421cca0789b2
--- /dev/null
+++ b/tests/neg-custom-args/qualified-types/subtyping_singletons_neg.scala
@@ -0,0 +1,8 @@
+def f(x: Int): Int = ???
+
+def test: Unit =
+ val x: Int = ???
+ val y: Int = ???
+ summon[2 <:< {v: Int with v == 1}] // error
+ summon[x.type <:< {v: Int with v == 1}] // error
+ //summon[y.type <:< {v: Int with v == x}] // FIXME
diff --git a/tests/neg-custom-args/qualified-types/subtyping_unfolding_neg.scala b/tests/neg-custom-args/qualified-types/subtyping_unfolding_neg.scala
new file mode 100644
index 000000000000..2af09733af41
--- /dev/null
+++ b/tests/neg-custom-args/qualified-types/subtyping_unfolding_neg.scala
@@ -0,0 +1,18 @@
+def tp[T](): T = ???
+
+abstract class C:
+ type T
+ val x: T
+
+def test: Unit =
+ val x: Int = ???
+ val z: Int = ???
+ val c1: C = ???
+ val c2: C = ???
+
+ summon[{v: Int with v == x} <:< {v: Int with v == z}] // error
+ summon[{v: C with v == c1} <:< {v: C with v == c2}] // error
+
+ // summon[{v: Int with v == (??? : Int)} <:< {v: Int with v == (??? : Int)}] // TODO(mbovel): should not compare some impure applications?
+
+ summon[{v: Int with v == tp[c1.T]()} <:< {v: Int with v == tp[c2.T]()}] // error
diff --git a/tests/neg-custom-args/qualified-types/syntax_unnamed_neg.scala b/tests/neg-custom-args/qualified-types/syntax_unnamed_neg.scala
new file mode 100644
index 000000000000..9191cf5db2db
--- /dev/null
+++ b/tests/neg-custom-args/qualified-types/syntax_unnamed_neg.scala
@@ -0,0 +1,7 @@
+case class Box[T](x: T)
+def id[T](x: T): T = x
+
+abstract class Test:
+ val v1: Box[Int with v1 > 0] // error: Cyclic reference
+ val v2: {v: Int with id[Int with v2 > 0](???) > 0} // error: Cyclic reference
+ val v3: {v: Int with (??? : Int with v3 == 2) > 0} // error: Cyclic reference
diff --git a/tests/pos-custom-args/qualified-types/adapt.scala b/tests/pos-custom-args/qualified-types/adapt.scala
new file mode 100644
index 000000000000..597c6906b19d
--- /dev/null
+++ b/tests/pos-custom-args/qualified-types/adapt.scala
@@ -0,0 +1,21 @@
+def f(x: Int): Int = ???
+case class IntBox(x: Int)
+case class Box[T](x: T)
+
+def f(x: Int, y: Int): {r: Int with r == x + y} = x + y
+
+def test: Unit =
+ val x: Int = ???
+ def g(x: Int): Int = ???
+
+ val v1: {v: Int with v == x + 1} = x + 1
+ val v2: {v: Int with v == f(x)} = f(x)
+ val v3: {v: Int with v == g(x)} = g(x)
+ val v4: {v: IntBox with v == IntBox(x)} = IntBox(x)
+ val v5: {v: Box[Int] with v == Box(x)} = Box(x)
+ val v6: {v: Int with v == x + f(x)} = x + f(x)
+ val v7: {v: Int with v == x + g(x)} = x + g(x)
+ val v8: {v: Int with v == f(x + 1)} = f(x + 1)
+ val v9: {v: Int with v == g(x + 1)} = g(x + 1)
+ val v12: {v: IntBox with v == IntBox(x + 1)} = IntBox(x + 1)
+ val v13: {v: Box[Int] with v == Box(x + 1)} = Box(x + 1)
diff --git a/tests/pos-custom-args/qualified-types/avoidance.scala b/tests/pos-custom-args/qualified-types/avoidance.scala
new file mode 100644
index 000000000000..bff6decb61a2
--- /dev/null
+++ b/tests/pos-custom-args/qualified-types/avoidance.scala
@@ -0,0 +1,6 @@
+def Test = ()
+
+ //val x =
+ // val y: Int = ???
+ // y: {v: Int with v == y}
+ // TODO(mbovel): proper avoidance for qualified types
diff --git a/tests/pos-custom-args/qualified-types/class_constraints.scala b/tests/pos-custom-args/qualified-types/class_constraints.scala
new file mode 100644
index 000000000000..71085de63d44
--- /dev/null
+++ b/tests/pos-custom-args/qualified-types/class_constraints.scala
@@ -0,0 +1,3 @@
+/*class foo(elem: Int with elem > 0)*/
+
+@main def Test = ()
diff --git a/tests/pos-custom-args/qualified-types/list_map.scala b/tests/pos-custom-args/qualified-types/list_map.scala
new file mode 100644
index 000000000000..38b4a73af647
--- /dev/null
+++ b/tests/pos-custom-args/qualified-types/list_map.scala
@@ -0,0 +1,9 @@
+type PosInt = {v: Int with v > 0}
+
+
+def inc(x: PosInt): PosInt = (x + 1).runtimeChecked
+
+@main def Test =
+ val l: List[PosInt] = List(1,2,3)
+ val l2: List[PosInt] = l.map(inc)
+ ()
diff --git a/tests/pos-custom-args/qualified-types/sized_lists.scala b/tests/pos-custom-args/qualified-types/sized_lists.scala
new file mode 100644
index 000000000000..529079debdde
--- /dev/null
+++ b/tests/pos-custom-args/qualified-types/sized_lists.scala
@@ -0,0 +1,11 @@
+def size(v: Vec): Int = ???
+type Vec
+
+def vec(s: Int): {v: Vec with size(v) == s} = ???
+def concat(v1: Vec, v2: Vec): {v: Vec with size(v) == size(v1) + size(v2)} = ???
+def sum(v1: Vec, v2: Vec with size(v1) == size(v2)): {v: Vec with size(v) == size(v1)} = ???
+
+@main def Test =
+ val v3: {v: Vec with size(v) == 3} = vec(3)
+ val v4: {v: Vec with size(v) == 4} = vec(4)
+ val v7: {v: Vec with size(v) == 7} = concat(v3, v4)
diff --git a/tests/pos-custom-args/qualified-types/sized_lists2.scala b/tests/pos-custom-args/qualified-types/sized_lists2.scala
new file mode 100644
index 000000000000..5b4b1576bada
--- /dev/null
+++ b/tests/pos-custom-args/qualified-types/sized_lists2.scala
@@ -0,0 +1,22 @@
+type Vec[T]
+object Vec:
+ def fill[T](n: Int, v: T):
+ {r: Vec[T] with r.len == n}
+ = ???
+extension [T](a: Vec[T])
+ def len: {r: Int with r >= 0} = ???
+ def concat(b: Vec[T]):
+ {r: Vec[T] with r.len == a.len + b.len}
+ = ???
+ def zip[S](b: Vec[S] with b.len == a.len):
+ {r: Vec[(T, S)] with r.len == a.len}
+ = ???
+
+@main def Test =
+ val n: Int with n >= 0 = ???
+ val m: Int with m >= 0 = ???
+ val v1 = Vec.fill(n, 0)
+ val v2 = Vec.fill(m, 1)
+ val v3 = v1.concat(v2)
+ val mPlusN = m + n
+ val v4: {r: Vec[(String, Int)] with r.len == mPlusN} = Vec.fill(mPlusN, "").zip(v3)
diff --git a/tests/pos-custom-args/qualified-types/subtyping_equality.scala b/tests/pos-custom-args/qualified-types/subtyping_equality.scala
new file mode 100644
index 000000000000..4b4c8d5e8989
--- /dev/null
+++ b/tests/pos-custom-args/qualified-types/subtyping_equality.scala
@@ -0,0 +1,48 @@
+def f(x: Int): Int = ???
+def g(x: Int): Int = ???
+def f2(x: Int, y: Int): Int = ???
+def g2(x: Int, y: Int): Int = ???
+
+case class IntBox(x: Int)
+case class Box[T](x: T)
+
+def test: Unit =
+ val a: Int = ???
+ val b: Int = ???
+ val c: Int = ???
+ val d: Int = ???
+
+ summon[{v: Int with v == 2} <:< {v: Int with v == 2}]
+ summon[{v: Int with v == f(a)} <:< {v: Int with v == f(a)}]
+
+ // Equality is reflexive, symmetric and transitive
+ summon[{v: Int with v == v} <:< {v: Int with true}]
+ summon[{v: Int with a == b} <:< {v: Int with true}]
+ summon[{v: Int with v == a} <:< {v: Int with v == a}]
+ summon[{v: Int with v == a} <:< {v: Int with a == v}]
+ summon[{v: Int with a == b} <:< {v: Int with b == a}]
+ summon[{v: Int with v == a && a > 3} <:< {v: Int with v > 3}]
+ summon[{v: Int with v == a && a == b} <:< {v: Int with v == b}]
+ summon[{v: Int with a == b && b == c} <:< {v: Int with a == c}]
+ summon[{v: Int with a == b && c == b} <:< {v: Int with a == c}]
+ summon[{v: Int with a == b && c == d && b == d} <:< {v: Int with b == d}]
+ summon[{v: Int with a == b && c == d && b == d} <:< {v: Int with a == c}]
+
+ // Equality is congruent over functions
+ summon[{v: Int with a == b} <:< {v: Int with f(a) == f(b)}]
+ summon[{v: Int with a == b} <:< {v: Int with f(f(a)) == f(f(b))}]
+ summon[{v: Int with c == f(a) && d == f(b) && a == b} <:< {v: Int with c == d}]
+ // the two first equalities in the premises are just used to test the behavior
+ // of the e-graph when `f(a)` and `f(b)` are inserted before `a == b`.
+ summon[{v: Int with c == f(a) && d == f(b) && a == b} <:< {v: Int with f(a) == f(b)}]
+ summon[{v: Int with c == f(a) && d == f(b) && a == b} <:< {v: Int with f(f(a)) == f(f(b))}]
+
+ // Equality is supported on Strings
+ summon[{v: String with v == "hello"} <:< {v: String with v == "hello"}]
+ summon[{v: String with v == "hello"} <:< {v: String with "hello" == v}]
+
+ // Equality is supported on case classes
+ summon[{v: IntBox with v == IntBox(3)} <:< {v: IntBox with v == IntBox(3)}]
+ summon[{v: IntBox with v == IntBox(3)} <:< {v: IntBox with IntBox(3) == v}]
+ summon[{v: Box[Int] with v == Box(3)} <:< {v: Box[Int] with v == Box(3)}]
+ summon[{v: Box[Int] with v == Box(3)} <:< {v: Box[Int] with Box(3) == v}]
diff --git a/tests/pos-custom-args/qualified-types/subtyping_lambdas.scala b/tests/pos-custom-args/qualified-types/subtyping_lambdas.scala
new file mode 100644
index 000000000000..39807c28a840
--- /dev/null
+++ b/tests/pos-custom-args/qualified-types/subtyping_lambdas.scala
@@ -0,0 +1,17 @@
+def toBool[T](x: T): Boolean = ???
+def tp[T](): Any = ???
+
+
+def test: Unit =
+ val v1: {l: List[Int] with l.forall(x => x > 0)} = ??? : {l: List[Int] with l.forall(x => x > 0)}
+ val v2: {l: List[Int] with l.forall(x => x > 0)} = ??? : {l: List[Int] with l.forall(y => y > 0)}
+ val v3: {l: List[Int] with l.forall(x => x > 0)} = ??? : {l: List[Int] with l.forall(_ > 0)}
+
+ val v4: {l: List[Int] with toBool((x: String) => x.length > 0)} = ??? : {l: List[Int] with toBool((y: String) => y.length > 0)}
+
+ val v5: {l: List[Int] with toBool((x: String) => tp[x.type]())} = ??? : {l: List[Int] with toBool((y: String) => tp[y.type]())}
+ val v6: {l: List[Int] with toBool((x: String, y: String) => tp[x.type]())} = ??? : {l: List[Int] with toBool((x: String, y: String) => tp[x.type]())}
+ val v7: {l: List[Int] with toBool((x: String) => tp[x.type]())} = ??? : {l: List[Int] with toBool((y: String) => tp[y.type]())}
+ val v8: {l: List[Int] with toBool((x: String, y: String) => tp[y.type]())} = ??? : {l: List[Int] with toBool((x: String, y: String) => tp[y.type]())}
+
+ val v9: {l: List[Int] with toBool((x: String) => (y: String) => x == y)} = ??? : {l: List[Int] with toBool((a: String) => (b: String) => a == b)}
diff --git a/tests/pos-custom-args/qualified-types/subtyping_normalization.scala b/tests/pos-custom-args/qualified-types/subtyping_normalization.scala
new file mode 100644
index 000000000000..bb961712510a
--- /dev/null
+++ b/tests/pos-custom-args/qualified-types/subtyping_normalization.scala
@@ -0,0 +1,29 @@
+def f(x: Int): Int = ???
+def id[T](x: T): T = x
+def opaqueSize[T](l: List[T]): Int = ???
+
+def test: Unit =
+ val x: Int = ???
+ val y: Int = ???
+ val z: Int = ???
+
+ summon[{v: Int with v == 2 + (x * y * y * z)} <:< {v: Int with v == (x * y * z * y) + 2}]
+ summon[{v: Int with v == x + 1} <:< {v: Int with v == 1 + x}]
+ summon[{v: Int with v == y + x} <:< {v: Int with v == x + y}]
+ summon[{v: Int with v == x + 2} <:< {v: Int with v == 1 + x + 1}]
+ summon[{v: Int with v == x + 2} <:< {v: Int with v == 1 + (x + 1)}]
+ summon[{v: Int with v == x + 2 * y} <:< {v: Int with v == y + x + y}]
+ summon[{v: Int with v == x + 2 * y} <:< {v: Int with v == y + (x + y)}]
+ summon[{v: Int with v == x + 3 * y} <:< {v: Int with v == 2 * y + x + y}]
+ summon[{v: Int with v == x + 3 * y} <:< {v: Int with v == 2 * y + (x + y)}]
+ summon[{v: Int with v == 0} <:< {v: Int with v == 1 - 1}]
+ summon[{v: Int with v == 0} <:< {v: Int with v == x - x}]
+ summon[{v: Int with v == 0} <:< {v: Int with v == x + (x * -1)}]
+ summon[{v: Int with v == x} <:< {v: Int with v == 1 + x - 1}]
+ summon[{v: Int with v == 4 * (x + 1)} <:< {v: Int with v == 2 * (x + 1) + 2 * (1 + x)}]
+ summon[{v: Int with v == 4 * (x / 2)} <:< {v: Int with v == 2 * (x / 2) + 2 * (x / 2)}]
+
+ summon[{v: Int with v == id(x + 1)} <:< {v: Int with v == id(1 + x)}]
+ summon[{v: Int with v == id(x + 1)} <:< {v: Int with v == id(x + 1)}]
+
+ summon[{v: List[Int] with opaqueSize(v) == 2 * x} <:< {v: List[Int] with opaqueSize(v) == x + x}]
diff --git a/tests/pos-custom-args/qualified-types/subtyping_objects.scala b/tests/pos-custom-args/qualified-types/subtyping_objects.scala
new file mode 100644
index 000000000000..da3afae0ebd8
--- /dev/null
+++ b/tests/pos-custom-args/qualified-types/subtyping_objects.scala
@@ -0,0 +1,57 @@
+class Box[T](val x: T)
+
+class Foo(val id: String):
+ def this(x: Int) = this(x.toString)
+
+case class PersonCase(name: String, age: Int)
+
+case class PersonCaseCurried(name: String)(val age: Int)
+
+case class PersonCaseMutable(name: String, var age: Int)
+
+case class PersonCaseSecondary(name: String, age: Int):
+ def this(name: String) = this(name, 0)
+
+def test: Unit =
+ summon[{b: Box[Int] with 3 == b.x} =:= {b: Box[Int] with b.x == 3}]
+ summon[{f: Foo with f.id == "hello"} =:= {f: Foo with "hello" == f.id}]
+
+ // new PersonCase
+ summon[{p: PersonCase with p == new PersonCase("Alice", 30)} =:= {p: PersonCase with p == new PersonCase("Alice", 30)}]
+ summon[{s: String with new PersonCase("Alice", 30).name == s} =:= {s: String with s == "Alice"}]
+ summon[{n: Int with new PersonCase("Alice", 30).age == n} =:= {n: Int with n == 30}]
+
+ // PersonCase
+ summon[{p: PersonCase with p == PersonCase("Alice", 30)} =:= {p: PersonCase with p == PersonCase("Alice", 30)}]
+ summon[{s: String with PersonCase("Alice", 30).name == s} =:= {s: String with s == "Alice"}]
+ summon[{n: Int with PersonCase("Alice", 30).age == n} =:= {n: Int with n == 30}]
+
+ // new PersonCaseCurried
+ summon[{p: PersonCaseCurried with p == new PersonCaseCurried("Alice")(30)} =:= {p: PersonCaseCurried with p == new PersonCaseCurried("Alice")(30)}]
+ summon[{s: String with new PersonCaseCurried("Alice")(30).name == s} =:= {s: String with s == "Alice"}]
+ summon[{n: Int with new PersonCaseCurried("Alice")(30).age == n} =:= {n: Int with n == 30}]
+
+ // PersonCaseCurried
+ summon[{p: PersonCaseCurried with p == PersonCaseCurried("Alice")(30)} =:= {p: PersonCaseCurried with p == PersonCaseCurried("Alice")(30)}]
+ summon[{s: String with PersonCaseCurried("Alice")(30).name == s} =:= {s: String with s == "Alice"}]
+ summon[{n: Int with PersonCaseCurried("Alice")(30).age == n} =:= {n: Int with n == 30}]
+
+ // new PersonCaseMutable
+ summon[{p: PersonCaseMutable with p == new PersonCaseMutable("Alice", 30)} =:= {p: PersonCaseMutable with p == new PersonCaseMutable("Alice", 30)}]
+ summon[{s: String with new PersonCaseMutable("Alice", 30).name == s} =:= {s: String with s == "Alice"}]
+ //summon[{n: Int with new PersonCaseMutable("Alice", 30).age == n} =:= {n: Int with n == 30}] // error
+
+ // PersonCaseMutable
+ summon[{p: PersonCaseMutable with p == PersonCaseMutable("Alice", 30)} =:= {p: PersonCaseMutable with p == PersonCaseMutable("Alice", 30)}]
+ summon[{s: String with PersonCaseMutable("Alice", 30).name == s} =:= {s: String with s == "Alice"}]
+ //summon[{n: Int with PersonCaseMutable("Alice", 30).age == n} =:= {n: Int with n == 30}] // error
+
+ // new PersonCaseSecondary
+ summon[{p: PersonCaseSecondary with p == new PersonCaseSecondary("Alice")} =:= {p: PersonCaseSecondary with p == new PersonCaseSecondary("Alice")}]
+ //summon[{s: String with new PersonCaseSecondary("Alice").name == s} =:= {s: String with s == "Alice"}] // error
+ //summon[{n: Int with new PersonCaseSecondary("Alice").age == n} =:= {n: Int with n == 0}] // error
+
+ // PersonCaseSecondary
+ summon[{p: PersonCaseSecondary with p == PersonCaseSecondary("Alice", 30)} =:= {p: PersonCaseSecondary with p == PersonCaseSecondary("Alice", 30)}]
+ //summon[{s: String with PersonCaseSecondary("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error
+ //summon[{n: Int with PersonCaseSecondary("Alice", 30).age == n} =:= {n: Int with n == 30}] // error
diff --git a/tests/pos-custom-args/qualified-types/subtyping_paths.scala b/tests/pos-custom-args/qualified-types/subtyping_paths.scala
new file mode 100644
index 000000000000..153dfb2bee64
--- /dev/null
+++ b/tests/pos-custom-args/qualified-types/subtyping_paths.scala
@@ -0,0 +1,18 @@
+def tp[T](): Boolean = ???
+
+class Outer:
+ class Inner:
+ type D
+ summon[{v: Boolean with tp[Inner.this.D]()} =:= {v: Boolean with tp[D]()}]
+
+object OuterO:
+ object InnerO:
+ type D
+ summon[{v: Boolean with tp[InnerO.this.D]()} =:= {v: Boolean with tp[D]()}]
+
+ // Before normalization:
+ // lhs: