Skip to content

Commit

Permalink
Improve type inference for literal named tuples (#20497)
Browse files Browse the repository at this point in the history
Adds a new `NamedTuple.build` method which fixes the types of the labels
first, as suggested in
#20456 (comment)

It requires `language.experimental.clauseInterleaving` language import.

Keeps `withNames` as a friendlier option for end-users

fixes #20456
  • Loading branch information
odersky committed May 30, 2024
2 parents 3c9d985 + bf0cd3c commit 01b404f
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 9 deletions.
10 changes: 7 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1596,9 +1596,13 @@ object desugar {
if ctx.mode.is(Mode.Type) then
AppliedTypeTree(ref(defn.NamedTupleTypeRef), namesTuple :: tup :: Nil)
else
TypeApply(
Apply(Select(ref(defn.NamedTupleModule), nme.withNames), tup),
namesTuple :: Nil)
Apply(
Apply(
TypeApply(
Select(ref(defn.NamedTupleModule), nme.build), // NamedTuple.build
namesTuple :: Nil), // ++ [(names...)]
Nil), // ++ ()
tup :: Nil) // .++ ((values...))

/** When desugaring a list pattern arguments `elems` adapt them and the
* expected type `pt` to each other. This means:
Expand Down
7 changes: 6 additions & 1 deletion library/src/scala/NamedTuple.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package scala
import scala.language.experimental.clauseInterleaving
import annotation.experimental
import compiletime.ops.boolean.*

Expand All @@ -19,6 +20,11 @@ object NamedTuple:

def unapply[N <: Tuple, V <: Tuple](x: NamedTuple[N, V]): Some[V] = Some(x)

/** A named tuple expression will desugar to a call to `build`. For instance,
* `(name = "Lyra", age = 23)` will desugar to `build[("name", "age")]()(("Lyra", 23))`.
*/
inline def build[N <: Tuple]()[V <: Tuple](x: V): NamedTuple[N, V] = x

extension [V <: Tuple](x: V)
inline def withNames[N <: Tuple]: NamedTuple[N, V] = x

Expand Down Expand Up @@ -214,4 +220,3 @@ object NamedTupleDecomposition:
/** The value types of a named tuple represented as a regular tuple. */
type DropNames[NT <: AnyNamedTuple] <: Tuple = NT match
case NamedTuple[_, x] => x

121 changes: 121 additions & 0 deletions tests/pos/named-tuples-ops-mirror.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import language.experimental.namedTuples
import NamedTuple.*

@FailsWith[HttpError]
trait GreetService derives HttpService:
@HttpInfo("GET", "/greet/{name}")
def greet(@HttpPath name: String): String
@HttpInfo("POST", "/greet/{name}")
def setGreeting(@HttpPath name: String, @HttpBody greeting: String): Unit

@main def Test =

val e = HttpService.endpoints[GreetService]

println(e.greet.describe)
println(e.setGreeting.describe)

// Type-safe server logic, driven by the ops-mirror,
// requires named tuple with same labels in the same order,
// and function that matches the required signature.
val logic = e.serverLogic:
(
greet = (name) => Right("Hello, " + name),
setGreeting = (name, greeting) => Right(())
)

val server = ServerBuilder()
.handleAll(logic)
.create(port = 8080)

sys.addShutdownHook(server.close())

end Test

// IMPLEMENTATION DETAILS FOLLOW

/** Assume existence of macro to generate this */
given (OpsMirror.Of[GreetService] {
type MirroredType = GreetService
type OperationLabels = ("greet", "setGreeting")
type Operations = (
OpsMirror.Operation { type InputTypes = (String *: EmptyTuple); type OutputType = String; type ErrorType = HttpError },
OpsMirror.Operation { type InputTypes = (String *: String *: EmptyTuple); type OutputType = Unit; type ErrorType = HttpError }
)
}) = new OpsMirror:
type MirroredType = GreetService
type OperationLabels = ("greet", "setGreeting")
type Operations = (
OpsMirror.Operation { type InputTypes = (String *: EmptyTuple); type OutputType = String; type ErrorType = HttpError },
OpsMirror.Operation { type InputTypes = (String *: String *: EmptyTuple); type OutputType = Unit; type ErrorType = HttpError }
)

object OpsMirror:
type Of[T] = OpsMirror { type MirroredType = T }

type Operation_I[I <: Tuple] = Operation { type InputTypes = I }
type Operation_O[O] = Operation { type OutputType = O }
type Operation_E[E] = Operation { type ErrorType = E }

trait Operation:
type InputTypes <: Tuple
type OutputType
type ErrorType

trait OpsMirror:
type MirroredType
type OperationLabels <: Tuple
type Operations <: Tuple

trait HttpService[T]:
def route(str: String): Route
trait Route

type Func[I <: Tuple, O, E] = I match
case EmptyTuple => Either[E, O]
case t *: EmptyTuple => t => Either[E, O]
case t *: u *: EmptyTuple => (t, u) => Either[E, O]

type ToFunc[T] = T match
case HttpService.Endpoint[i, o, e] => Func[i, o, e]

final class FailsWith[E] extends scala.annotation.Annotation
final class HttpInfo(method: String, route: String) extends scala.annotation.Annotation
final class HttpBody() extends scala.annotation.Annotation
final class HttpPath() extends scala.annotation.Annotation

sealed trait HttpError

object HttpService:
opaque type Endpoint[I <: Tuple, O, E] = Route

extension [I <: Tuple, O, E](e: Endpoint[I, O, E])
def describe: String = ??? // some thing that looks inside the Route to debug it

type ToEndpoints[Ops <: Tuple] <: Tuple = Ops match
case EmptyTuple => EmptyTuple
case op *: ops => (op, op, op) match
case (OpsMirror.Operation_I[i]) *: (OpsMirror.Operation_O[o]) *: (OpsMirror.Operation_E[e]) *: _ =>
Endpoint[i, o, e] *: ToEndpoints[ops]

trait Handler

class Endpoints[T](val model: HttpService[T]) extends Selectable:
type Fields <: AnyNamedTuple
def selectDynamic(name: String): Route = model.route(name)

def serverLogic(funcs: NamedTuple[Names[Fields], Tuple.Map[DropNames[Fields], ToFunc]]): List[Handler] = ???

def derived[T](using OpsMirror.Of[T]): HttpService[T] = ??? // inline method to create routes

def endpoints[T](using model: HttpService[T], m: OpsMirror.Of[T]): Endpoints[T] {
type Fields = NamedTuple[m.OperationLabels, ToEndpoints[m.Operations]]
} =
new Endpoints(model) { type Fields = NamedTuple[m.OperationLabels, ToEndpoints[m.Operations]] }

class ServerBuilder():
def handleAll(hs: List[HttpService.Handler]): this.type = this
def create(port: Int): Server = Server()

class Server():
def close(): Unit = ()
1 change: 1 addition & 0 deletions tests/run/named-tuples.check
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ Bob is younger than Bill
Bob is younger than Lucy
Bill is younger than Lucy
(((Lausanne,Pully),Preverenges),((1003,1009),1028))
118
15 changes: 10 additions & 5 deletions tests/run/named-tuples.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,20 @@ val _: CombinedInfo = bob ++ addr
val addr4 = addr3.zip("Preverenges", 1028)
println(addr4)

val reducer: (map: Person => Int, reduce: (Int, Int) => Int) =
(map = _.age, reduce = _ + _)

extension [T](xs: List[T])
def mapReduce[U](reducer: (map: T => U, reduce: (U, U) => U)): U =
xs.map(reducer.map).reduce(reducer.reduce)

val totalAge = persons.mapReduce(reducer)
println(totalAge)

// testing conversions
object Conv:

val p: (String, Int) = bob
def f22(x: (String, Int)) = x._1
def f22(x: String) = x
f22(bob)





0 comments on commit 01b404f

Please sign in to comment.