Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add primitive compiletime operations on singleton types #7628

Merged
merged 17 commits into from
Jan 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,11 @@ class Definitions {
@tu lazy val CompiletimeTesting_ErrorKind: Symbol = ctx.requiredModule("scala.compiletime.testing.ErrorKind")
@tu lazy val CompiletimeTesting_ErrorKind_Parser: Symbol = CompiletimeTesting_ErrorKind.requiredMethod("Parser")
@tu lazy val CompiletimeTesting_ErrorKind_Typer: Symbol = CompiletimeTesting_ErrorKind.requiredMethod("Typer")
@tu lazy val CompiletimeOpsPackageObject: Symbol = ctx.requiredModule("scala.compiletime.ops.package")
@tu lazy val CompiletimeOpsPackageObjectAny: Symbol = ctx.requiredModule("scala.compiletime.ops.package.any")
@tu lazy val CompiletimeOpsPackageObjectInt: Symbol = ctx.requiredModule("scala.compiletime.ops.package.int")
@tu lazy val CompiletimeOpsPackageObjectString: Symbol = ctx.requiredModule("scala.compiletime.ops.package.string")
@tu lazy val CompiletimeOpsPackageObjectBoolean: Symbol = ctx.requiredModule("scala.compiletime.ops.package.boolean")

/** The `scalaShadowing` package is used to safely modify classes and
* objects in scala so that they can be used from dotty. They will
Expand Down Expand Up @@ -898,6 +903,26 @@ class Definitions {
final def isCompiletime_S(sym: Symbol)(implicit ctx: Context): Boolean =
sym.name == tpnme.S && sym.owner == CompiletimePackageObject.moduleClass

private val compiletimePackageAnyTypes: Set[Name] = Set(tpnme.Equals, tpnme.NotEquals)
private val compiletimePackageIntTypes: Set[Name] = Set(
tpnme.Plus, tpnme.Minus, tpnme.Times, tpnme.Div, tpnme.Mod,
tpnme.Lt, tpnme.Gt, tpnme.Ge, tpnme.Le,
tpnme.Abs, tpnme.Negate, tpnme.Min, tpnme.Max, tpnme.ToString,
)
private val compiletimePackageBooleanTypes: Set[Name] = Set(tpnme.Not, tpnme.Xor, tpnme.And, tpnme.Or)
private val compiletimePackageStringTypes: Set[Name] = Set(tpnme.Plus)

final def isCompiletimeAppliedType(sym: Symbol)(implicit ctx: Context): Boolean = {
def isOpsPackageObjectAppliedType: Boolean =
sym.owner == CompiletimeOpsPackageObjectAny.moduleClass && compiletimePackageAnyTypes.contains(sym.name) ||
sym.owner == CompiletimeOpsPackageObjectInt.moduleClass && compiletimePackageIntTypes.contains(sym.name) ||
sym.owner == CompiletimeOpsPackageObjectBoolean.moduleClass && compiletimePackageBooleanTypes.contains(sym.name) ||
sym.owner == CompiletimeOpsPackageObjectString.moduleClass && compiletimePackageStringTypes.contains(sym.name)

sym.isType && (isCompiletime_S(sym) || isOpsPackageObjectAppliedType)
}


// ----- Symbol sets ---------------------------------------------------

@tu lazy val AbstractFunctionType: Array[TypeRef] = mkArityArray("scala.runtime.AbstractFunction", MaxImplementedFunctionArity, 0)
Expand Down
23 changes: 22 additions & 1 deletion compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,34 @@ object StdNames {
final val Product: N = "Product"
final val PartialFunction: N = "PartialFunction"
final val PrefixType: N = "PrefixType"
final val S: N = "S"
final val Serializable: N = "Serializable"
final val Singleton: N = "Singleton"
final val Throwable: N = "Throwable"
final val IOOBException: N = "IndexOutOfBoundsException"
final val FunctionXXL: N = "FunctionXXL"

final val Abs: N = "Abs"
final val And: N = "&&"
final val Div: N = "/"
final val Equals: N = "=="
final val Ge: N = ">="
final val Gt: N = ">"
final val Le: N = "<="
final val Lt: N = "<"
final val Max: N = "Max"
final val Min: N = "Min"
final val Minus: N = "-"
final val Mod: N = "%"
final val Negate: N = "Negate"
final val Not: N = "!"
final val NotEquals: N = "!="
final val Or: N = "||"
final val Plus: N = "+"
final val S: N = "S"
final val Times: N = "*"
final val ToString: N = "ToString"
final val Xor: N = "^"

final val ClassfileAnnotation: N = "ClassfileAnnotation"
final val ClassManifest: N = "ClassManifest"
final val Enum: N = "Enum"
Expand Down
8 changes: 6 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -371,12 +371,16 @@ class TypeApplications(val self: Type) extends AnyVal {
// just eta-reduction (ignoring variance annotations).
// See i2201*.scala for examples where more aggressive
// reduction would break type inference.
dealiased.paramRefs == dealiasedArgs
dealiased.paramRefs == dealiasedArgs ||
defn.isCompiletimeAppliedType(tyconBody.typeSymbol)
case _ => false
}
}
if ((dealiased eq stripped) || followAlias)
try dealiased.instantiate(args)
try {
val instantiated = dealiased.instantiate(args)
if (followAlias) instantiated.normalized else instantiated
}
catch { case ex: IndexOutOfBoundsException => AppliedType(self, args) }
else AppliedType(self, args)
}
Expand Down
18 changes: 15 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
compareLower(bounds(param2), tyconIsTypeRef = false)
case tycon2: TypeRef =>
isMatchingApply(tp1) ||
defn.isCompiletime_S(tycon2.symbol) && compareS(tp2, tp1, fromBelow = true) || {
defn.isCompiletimeAppliedType(tycon2.symbol) && compareCompiletimeAppliedType(tp2, tp1, fromBelow = true) || {
tycon2.info match {
case info2: TypeBounds =>
compareLower(info2, tyconIsTypeRef = true)
Expand Down Expand Up @@ -1005,7 +1005,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
case tycon1: TypeRef =>
val sym = tycon1.symbol
!sym.isClass && {
defn.isCompiletime_S(sym) && compareS(tp1, tp2, fromBelow = false) ||
defn.isCompiletimeAppliedType(sym) && compareCompiletimeAppliedType(tp1, tp2, fromBelow = false) ||
recur(tp1.superType, tp2) ||
tryLiftedToThis1
}
Expand All @@ -1015,7 +1015,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
false
}

/** Compare `tp` of form `S[arg]` with `other`, via ">:>` if fromBelow is true, "<:<" otherwise.
/** Compare `tp` of form `S[arg]` with `other`, via ">:>" if fromBelow is true, "<:<" otherwise.
* If `arg` is a Nat constant `n`, proceed with comparing `n + 1` and `other`.
* Otherwise, if `other` is a Nat constant `n`, proceed with comparing `arg` and `n - 1`.
*/
Expand All @@ -1037,6 +1037,18 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
case _ => false
}

/** Compare `tp` of form `tycon[...args]`, where `tycon` is a scala.compiletime type,
* with `other` via ">:>" if fromBelow is true, "<:<" otherwise.
* Delegates to compareS if `tycon` is scala.compiletime.S. Otherwise, constant folds if possible.
*/
def compareCompiletimeAppliedType(tp: AppliedType, other: Type, fromBelow: Boolean): Boolean = {
if (defn.isCompiletime_S(tp.tycon.typeSymbol)) compareS(tp, other, fromBelow)
else {
val folded = tp.tryCompiletimeConstantFold
if (fromBelow) recur(other, folded) else recur(folded, other)
}
}

/** Like tp1 <:< tp2, but returns false immediately if we know that
* the case was covered previously during subtyping.
*/
Expand Down
98 changes: 88 additions & 10 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3595,19 +3595,97 @@ object Types {
case _ =>
NoType
}
if (defn.isCompiletime_S(tycon.symbol) && args.length == 1)
trace(i"normalize S $this", typr, show = true) {
args.head.normalized match {
case ConstantType(Constant(n: Int)) if n >= 0 && n < Int.MaxValue =>
ConstantType(Constant(n + 1))
case none => tryMatchAlias
}
}
else tryMatchAlias

tryCompiletimeConstantFold.orElse(tryMatchAlias)

case _ =>
NoType
}

def tryCompiletimeConstantFold(implicit ctx: Context): Type = tycon match {
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
def constValue(tp: Type): Option[Any] = tp match {
case ConstantType(Constant(n)) => Some(n)
case _ => None
}

def boolValue(tp: Type): Option[Boolean] = tp match {
case ConstantType(Constant(n: Boolean)) => Some(n)
case _ => None
}

def intValue(tp: Type): Option[Int] = tp match {
case ConstantType(Constant(n: Int)) => Some(n)
case _ => None
}

def stringValue(tp: Type): Option[String] = tp match {
case ConstantType(Constant(n: String)) => Some(n)
case _ => None
}

def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)

def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
extractor(args.head.normalized).map(a => ConstantType(Constant(op(a))))

def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
for {
a <- extractor(args.head.normalized)
b <- extractor(args.tail.head.normalized)
} yield ConstantType(Constant(op(a, b)))

trace(i"compiletime constant fold $this", typr, show = true) {
val name = tycon.symbol.name
val owner = tycon.symbol.owner
val nArgs = args.length
val constantType =
if (owner == defn.CompiletimePackageObject.moduleClass) name match {
case tpnme.S if nArgs == 1 => constantFold1(natValue, _ + 1)
case _ => None
} else if (owner == defn.CompiletimeOpsPackageObjectAny.moduleClass) name match {
case tpnme.Equals if nArgs == 2 => constantFold2(constValue, _ == _)
case tpnme.NotEquals if nArgs == 2 => constantFold2(constValue, _ != _)
case _ => None
} else if (owner == defn.CompiletimeOpsPackageObjectInt.moduleClass) name match {
case tpnme.Abs if nArgs == 1 => constantFold1(intValue, _.abs)
case tpnme.Negate if nArgs == 1 => constantFold1(intValue, x => -x)
case tpnme.ToString if nArgs == 1 => constantFold1(intValue, _.toString)
case tpnme.Plus if nArgs == 2 => constantFold2(intValue, _ + _)
case tpnme.Minus if nArgs == 2 => constantFold2(intValue, _ - _)
case tpnme.Times if nArgs == 2 => constantFold2(intValue, _ * _)
case tpnme.Div if nArgs == 2 => constantFold2(intValue, {
case (_, 0) => throw new TypeError("Division by 0")
case (a, b) => a / b
})
case tpnme.Mod if nArgs == 2 => constantFold2(intValue, {
case (_, 0) => throw new TypeError("Modulo by 0")
case (a, b) => a % b
})
case tpnme.Lt if nArgs == 2 => constantFold2(intValue, _ < _)
case tpnme.Gt if nArgs == 2 => constantFold2(intValue, _ > _)
case tpnme.Ge if nArgs == 2 => constantFold2(intValue, _ >= _)
case tpnme.Le if nArgs == 2 => constantFold2(intValue, _ <= _)
case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _)
case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _)
case _ => None
} else if (owner == defn.CompiletimeOpsPackageObjectString.moduleClass) name match {
case tpnme.Plus if nArgs == 2 => constantFold2(stringValue, _ + _)
case _ => None
} else if (owner == defn.CompiletimeOpsPackageObjectBoolean.moduleClass) name match {
case tpnme.Not if nArgs == 1 => constantFold1(boolValue, x => !x)
case tpnme.And if nArgs == 2 => constantFold2(boolValue, _ && _)
case tpnme.Or if nArgs == 2 => constantFold2(boolValue, _ || _)
case tpnme.Xor if nArgs == 2 => constantFold2(boolValue, _ ^ _)
case _ => None
} else None

constantType.getOrElse(NoType)
}

case _ => NoType
}

def lowerBound(implicit ctx: Context): Type = tycon.stripTypeVar match {
case tycon: TypeRef =>
tycon.info match {
Expand Down Expand Up @@ -3974,7 +4052,7 @@ object Types {
myReduced =
trace(i"reduce match type $this $hashCode", typr, show = true) {
try
typeComparer.matchCases(scrutinee, cases)(trackingCtx)
typeComparer.matchCases(scrutinee.normalized, cases)(trackingCtx)
catch {
case ex: Throwable =>
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
Expand Down
54 changes: 51 additions & 3 deletions docs/docs/reference/metaprogramming/inline.md
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ val intTwo: 2 = natTwo

The `scala.compiletime` package contains helper definitions that provide support for compile time operations over values. They are described in the following.

#### `constValue`, `constValueOpt`, and the `S` combinator
### `constValue`, `constValueOpt`, and the `S` combinator

`constvalue` is a function that produces the constant value represented by a
type.
Expand All @@ -317,7 +317,7 @@ enabling us to handle situations where a value is not present. Note that `S` is
the type of the successor of some singleton type. For example the type `S[1]` is
the singleton type `2`.

#### `erasedValue`
### `erasedValue`

So far we have seen inline methods that take terms (tuples and integers) as
parameters. What if we want to base case distinctions on types instead? For
Expand Down Expand Up @@ -381,7 +381,7 @@ final val two = toIntT[Succ[Succ[Zero.type]]]
behavior. Since `toInt` performs static checks over the static type of `N` we
can safely use it to scrutinize its return type (`S[S[Z]]` in this case).

#### `error`
### `error`

The `error` method is used to produce user-defined compile errors during inline expansion.
It has the following signature:
Expand Down Expand Up @@ -411,6 +411,54 @@ inline def fail(p1: => Any) = {
fail(identity("foo")) // error: failed on: identity("foo")
```

### The `scala.compiletime.ops` package

The `scala.compiletime.ops` package contains types that provide support for
primitive operations on singleton types. For example,
`scala.compiletime.ops.int.*` provides support for multiplying two singleton
`Int` types, and `scala.compiletime.ops.boolean.&&` for the conjunction of two
`Boolean` types. When all arguments to a type in `scala.compiletime.ops` are
singleton types, the compiler can evaluate the result of the operation.

```scala
import scala.compiletime.ops.int._
import scala.compiletime.ops.boolean._

val conjunction: true && true = true
val multiplication: 3 * 5 = 15
```

Many of these singleton operation types are meant to be used infix (as in [SLS §
3.2.8](https://www.scala-lang.org/files/archive/spec/2.12/03-types.html#infix-types)),
and are annotated with [`@infix`](scala.annotation.infix) accordingly.

Since type aliases have the same precedence rules as their term-level
equivalents, the operations compose with the expected precedence rules:

```scala
import scala.compiletime.ops.int._
val x: 1 + 2 * 3 = 7
```

The operation types are located in packages named after the type of the
left-hand side parameter: for instance, `scala.compiletime.int.+` represents
addition of two numbers, while `scala.compiletime.string.+` represents string
concatenation. To use both and distinguish the two types from each other, a
match type can dispatch to the correct implementation:

```scala
import scala.compiletime.ops._
import scala.annotation.infix

@infix type +[X <: Int | String, Y <: Int | String] = (X, Y) match {
case (Int, Int) => int.+[X, Y]
case (String, String) => string.+[X, Y]
}

val concat: "a" + "b" = "ab"
val addition: 1 + 1 = 2
```

## Summoning Implicits Selectively

It is foreseen that many areas of typelevel programming can be done with rewrite
Expand Down
Loading