Skip to content

Commit

Permalink
Arbitrary arity tuple (#435)
Browse files Browse the repository at this point in the history
* Add arbitrary arity tuple Quat

* Decode arbitrary arity tuple

* Use runtime.Tuples to construct Tuple

* Fat finger fix

* Support beta reduction with arbitrary tuples

* Support empty tuple decoding

* Parse arbitrary tuples construction

* Remove commented out code

---------

Co-authored-by: Alexander Ioffe <deusaquilus@gmail.com>
  • Loading branch information
Primetalk and deusaquilus committed May 22, 2024
1 parent 1753850 commit 4a336ec
Show file tree
Hide file tree
Showing 5 changed files with 304 additions and 17 deletions.
63 changes: 46 additions & 17 deletions quill-sql/src/main/scala/io/getquill/generic/GenericDecoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,31 @@ object GenericDecoder {
}
} // end flatten

// similar to flatten but without labels
@tailrec
def values[ResultRow: Type, Session: Type, Types: Type](
index: Int,
baseIndex: Expr[Int],
resultRow: Expr[ResultRow],
session: Expr[Session]
)(accum: List[FlattenData] = List())(using Quotes): List[FlattenData] = {
import quotes.reflect.{Term => QTerm, _}

Type.of[Types] match {
case '[tpe *: types] if Expr.summon[GenericDecoder[ResultRow, Session, tpe, DecodingType.Specific]].isEmpty =>
val result = decode[tpe, ResultRow, Session](index, baseIndex, resultRow, session)
val nextIndex = result.index + 1
values[ResultRow, Session, types](nextIndex, baseIndex, resultRow, session)(result +: accum)
case '[tpe *: types] =>
val result = decode[tpe, ResultRow, Session](index, baseIndex, resultRow, session, None)
val nextIndex = index + 1
values[ResultRow, Session, types](nextIndex, baseIndex, resultRow, session)(result +: accum)
case '[EmptyTuple] => accum

case typesTup => report.throwError("Cannot Derive Product during Values extraction:\n" + typesTup)
}
} // end values

def decodeOptional[T: Type, ResultRow: Type, Session: Type](index: Int, baseIndex: Expr[Int], resultRow: Expr[ResultRow], session: Expr[Session])(using Quotes): FlattenData = {
import quotes.reflect._
// Try to summon a specific optional from the context, this may not exist since
Expand Down Expand Up @@ -163,9 +188,9 @@ object GenericDecoder {
// List((new Name(Decoder("Joe") || Decoder("Bloggs")), Decoder(123))
// This is what needs to be fed into the constructor of the outer-entity i.e.
// new Person((new Name(Decoder("Joe") || Decoder("Bloggs")), Decoder(123))
val productElments = flattenData.map(_.decodedExpr)
val productElements = flattenData.map(_.decodedExpr)
// actually doing the construction i.e. `new Person(...)`
val constructed = ConstructDecoded[T](types, productElments, m)
val constructed = ConstructDecoded[T](types, productElements, m)

// E.g. for Person("Joe", 123) the List(q"!nullChecker(0,row)", q"!nullChecker(1,row)") columns
// that eventually turn into List(!NullChecker("Joe"), !NullChecker(123)) columns.
Expand All @@ -192,6 +217,11 @@ object GenericDecoder {
TypeRepr.of[T] <:< TypeRepr.of[Option[Any]]
}

private def isTuple[T: Type](using Quotes) = {
import quotes.reflect._
TypeRepr.of[T] <:< TypeRepr.of[Tuple]
}

private def isBuiltInType[T: Type](using Quotes) = {
import quotes.reflect._
isOption[T] || (TypeRepr.of[T] <:< TypeRepr.of[Seq[_]])
Expand All @@ -207,6 +237,16 @@ object GenericDecoder {
case '[Option[tpe]] =>
decodeOptional[tpe, ResultRow, Session](index, baseIndex, resultRow, session)
}
} else if (isTuple[T]) {
if (TypeRepr.of[T] <:< TypeRepr.of[EmptyTuple]) {
FlattenData(Type.of[T], '{ EmptyTuple }, '{ false }, index)
} else {
val flattenData = values[ResultRow, Session, T](index, baseIndex, resultRow, session)().reverse
val elementTerms = flattenData.map(_.decodedExpr) // expressions that represent values for tuple elements
val constructed = '{ scala.runtime.Tuples.fromArray(${ Varargs(elementTerms) }.toArray[Any](Predef.summon[ClassTag[Any]]).asInstanceOf[Array[Object]]).asInstanceOf[T] }
val nullChecks = flattenData.map(_._3).reduce((a, b) => '{ $a || $b })
FlattenData(Type.of[T], constructed, nullChecks, flattenData.last.index)
}
} else {
// specifically if there is a decoder found, allow optional override of the index via a resolver
val decoderIndex = overriddenIndex.getOrElse(elementIndex)
Expand Down Expand Up @@ -341,21 +381,10 @@ object ConstructDecoded {
val tpe = TypeRepr.of[T]
val constructor = TypeRepr.of[T].typeSymbol.primaryConstructor
// If we are a tuple, we can easily construct it
if (tpe <:< TypeRepr.of[Tuple]) {
val construct =
Apply(
TypeApply(
Select(New(TypeTree.of[T]), constructor),
types.map { tpe =>
tpe match {
case '[tt] => TypeTree.of[tt]
}
}
),
terms.map(_.asTerm)
)
// println(s"=========== Create from Tuple Constructor ${Format.Expr(construct.asExprOf[T])} ===========")
construct.asExprOf[T]
if (tpe <:< TypeRepr.of[EmptyTuple]) {
'{EmptyTuple}
} else if (tpe <:< TypeRepr.of[Tuple]) {
'{scala.runtime.Tuples.fromIArray(IArray(${Varargs(terms)})).asInstanceOf[T]}
// If we are a case class with no generic parameters, we can easily construct it
} else if (tpe.classSymbol.exists(_.flags.is(Flags.Case)) && !constructor.paramSymss.exists(_.exists(_.isTypeParam))) {
val construct =
Expand Down
108 changes: 108 additions & 0 deletions quill-sql/src/main/scala/io/getquill/parser/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ trait ParserLibrary extends ParserFactory {
protected def functionParser(using Quotes, TranspileConfig) = ParserChain.attempt(FunctionParser(_))
protected def functionApplyParser(using Quotes, TranspileConfig) = ParserChain.attempt(FunctionApplyParser(_))
protected def valParser(using Quotes, TranspileConfig) = ParserChain.attempt(ValParser(_))
protected def arbitraryTupleParser(using Quotes, TranspileConfig) = ParserChain.attempt(ArbitraryTupleBlockParser(_))
protected def blockParser(using Quotes, TranspileConfig) = ParserChain.attempt(BlockParser(_))
protected def extrasParser(using Quotes, TranspileConfig) = ParserChain.attempt(ExtrasParser(_))
protected def operationsParser(using Quotes, TranspileConfig) = ParserChain.attempt(OperationsParser(_))
Expand Down Expand Up @@ -88,6 +89,7 @@ trait ParserLibrary extends ParserFactory {
.orElse(functionParser) // decided to have it be it's own parser unlike Quill3
.orElse(patMatchParser)
.orElse(valParser)
.orElse(arbitraryTupleParser)
.orElse(blockParser)
.orElse(operationsParser)
.orElse(extrasParser)
Expand Down Expand Up @@ -153,6 +155,112 @@ class ValParser(val rootParse: Parser)(using Quotes, TranspileConfig)
case Unseal(ValDefTerm(ast)) => ast
}
}
/**
* Matches `runtime.Tuples.cons(head,tail)`.
*/
object TupleCons {
def unapply(using Quotes)(t: quotes.reflect.Term): Option[(quotes.reflect.Term, quotes.reflect.Term)] = {
import quotes.reflect.*
t match {
case Apply(Select(Select(Ident("runtime"), "Tuples"), "cons"), List(head, tail)) =>
Some((head, tail))
case _ =>
None
}
}
}

/**
* Matches inner.asInstanceOf[T]: T
*/
object AsInstanceOf {
def unapply(using Quotes)(term: quotes.reflect.Term): Option[quotes.reflect.Term] = {
import quotes.reflect._
term match {
case TypeApply(Select(inner, "asInstanceOf"), _) => Some(inner)
case _ => None
}
}
}

/**
* Matches an inlined call to `Tuple.*:`:
* {{{
* {
* val Tuple_this: scala.Tuple$package.EmptyTuple.type = scala.Tuple$package.EmptyTuple
*
* (scala.runtime.Tuples.cons(i, Tuple_this).asInstanceOf[scala.*:[scala.Int, scala.Tuple$package.EmptyTuple.type]]: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple])
* }
* }}}
*/
object ArbitraryTupleConstructionInlined {
def unapply(using Quotes)(t: quotes.reflect.Term): Option[(quotes.reflect.Term, quotes.reflect.Term)] = {
import quotes.reflect.{Ident => TIdent, *}
t match {
case
Inlined(
_,
List(ValDef("Tuple_this", _, Some(prevTuple))),
Typed(AsInstanceOf(TupleCons(head, TIdent("Tuple_this"))), _)
) =>
Some((head, prevTuple))
case _ => None
}
}
}

/**
* Parses a few cases of arbitrary tuples.
*
* Scala 3 produces a few different trees for arbitrary tuples. Method `*:` is marked as inline.
* Under the hood it actually invokes `Tuples.cons` function:
* {{{
* inline def *: [H, This >: this.type <: Tuple] (x: H): H *: This =
* runtime.Tuples.cons(x, this).asInstanceOf[H *: This]
* }}}
* So, at least we have to match Tuples.cons.
* However, it's not the only variation. Scala also produces a block with intermediate val `Tuple_this` definitions:
* {{{
* {
* val Tuple_this: scala.Tuple$package.EmptyTuple.type = scala.Tuple$package.EmptyTuple
* val `Tuple_this₂`: scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple] = (scala.runtime.Tuples.cons("", Tuple_this).asInstanceOf[scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple.type]]: scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple])
*
* (scala.runtime.Tuples.cons(1, `Tuple_this₂`).asInstanceOf[scala.*:[scala.Int, scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple]]]: scala.*:[scala.Int, scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple]])
* }
* }}}
*/
class ArbitraryTupleBlockParser(val rootParse: Parser)(using Quotes, TranspileConfig)
extends Parser(rootParse)
with PatternMatchingValues {

import quotes.reflect.{Block => TBlock, Ident => TIdent, _}

def attempt = {
case '{EmptyTuple} =>
ast.Tuple(List())
case '{$a *: EmptyTuple} =>
val aAst = rootParse(a)
ast.Tuple(List(aAst))
case inlined@Unseal(ArbitraryTupleConstructionInlined(singleValue, prevTuple)) =>
val headAst = rootParse(singleValue.asExpr)
val prevTupleAst = rootParse(prevTuple.asExpr)
prevTupleAst match {
case ast.Tuple(lst) => ast.Tuple(headAst :: lst)
case _ =>
throw IllegalArgumentException(s"Unexpected tuple ast ${prevTupleAst}")
}
case block@Unseal(TBlock(parts, ArbitraryTupleConstructionInlined(head,Typed(AsInstanceOf(TupleCons(head2,TIdent("Tuple_this"))), _)) )) if (parts.length > 0) =>
val headAst = rootParse(head.asExpr)
val head2Ast = rootParse(head2.asExpr)
val partsAsts = headAst :: head2Ast :: parts.reverse.flatMap{
case ValDef("Tuple_this", tpe, Some(TIdent("EmptyTuple"))) => List()
case ValDef("Tuple_this", tpe, Some(Typed(AsInstanceOf(TupleCons(next,TIdent("Tuple_this"))), _))) => List(rootParse(next.asExpr))
case ValDef("Tuple_this", tpe, Some(unknown)) =>
throw IllegalArgumentException(s"Unexpected Tuple_this = ${unknown.show}")
}
Tuple(partsAsts)
}
}

class BlockParser(val rootParse: Parser)(using Quotes, TranspileConfig)
extends Parser(rootParse)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,12 @@ object ParserHelpers {
binds.zipWithIndex.flatMap { case (bind, idx) =>
tupleBindsPath(bind, path :+ s"_${idx + 1}")
}
case Unapply(TypeApply(Select(TIdent("*:"), "unapply"), types), implicits, List(h, t)) =>
List(
tupleBindsPath(h, path :+ s"head"),
tupleBindsPath(h, path :+ s"tail")
)
.flatten
// If it's a "case _ => ..." then that just translates into the body expression so we don't
// need a clause to beta reduction over the entire partial-function
case TIdent("_") =>
Expand Down
22 changes: 22 additions & 0 deletions quill-sql/src/main/scala/io/getquill/quat/QuatMaking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,25 @@ trait QuatMakingBase {
None
}

object ArbitraryArityTupleType {
def unapply(using Quotes)(tpe: quotes.reflect.TypeRepr): Option[List[quotes.reflect.TypeRepr]] =
if (tpe.is[Tuple])
Some(tupleParts(tpe))
else
None

@tailrec
def tupleParts(using Quotes)(tpe: quotes.reflect.TypeRepr, accum: List[quotes.reflect.TypeRepr] = Nil): List[quotes.reflect.TypeRepr] =
tpe.asType match {
case '[h *: t] =>
val htpe = quotes.reflect.TypeRepr.of[h]
val ttpe = quotes.reflect.TypeRepr.of[t]
tupleParts(ttpe, htpe :: accum)
case '[EmptyTuple] =>
accum.reverse
}
}

object OptionType {
def unapply(using Quotes)(tpe: quotes.reflect.TypeRepr): Option[quotes.reflect.TypeRepr] = {
import quotes.reflect._
Expand Down Expand Up @@ -384,6 +403,9 @@ trait QuatMakingBase {
case CaseClassBaseType(name, fields) if !existsEncoderFor(tpe) || tpe <:< TypeRepr.of[Udt] =>
Quat.Product(name, fields.map { case (fieldName, fieldType) => (fieldName, parseType(fieldType)) })

case ArbitraryArityTupleType(tupleParts) =>
Quat.Product("Tuple", tupleParts.zipWithIndex.map { case (fieldType, idx) => (s"_${idx + 1}", parseType(fieldType)) })

// If we are already inside a bounded type, treat an arbitrary type as a interface list
case ArbitraryBaseType(name, fields) if (boundedInterfaceType) =>
Quat.Product(name, fields.map { case (fieldName, fieldType) => (fieldName, parseType(fieldType)) })
Expand Down
122 changes: 122 additions & 0 deletions quill-sql/src/test/scala/io/getquill/ArbitraryTupleSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package io.getquill

import io.getquill.context.ExecutionType.Static
import io.getquill.context.mirror.{MirrorSession, Row}
import io.getquill.generic.TupleMember

class ArbitraryTupleSpec extends Spec {

val ctx = new MirrorContext(PostgresDialect, Literal)
import ctx._

type MyRow1 = (Int, String)
type MyRow2 = Int *: String *: EmptyTuple

inline def myRow1Query = quote {
querySchema[MyRow1]("my_table", t => t._1 -> "int_field", t => t._2 -> "string_field")
}

inline def myRow2Query = quote {
querySchema[MyRow2]("my_table", t => t._1 -> "int_field", t => t._2 -> "string_field")
}

"ordinary tuple" in {
val result = ctx.run(myRow1Query)

result.string mustEqual "SELECT x.int_field, x.string_field FROM my_table x"
result.extractor(Row(123, "St"), MirrorSession.default) mustEqual
(123, "St")
}

"ordinary tuple swap" in {

transparent inline def swapped: Quoted[EntityQuery[(String, Int)]] = quote {
myRow1Query.map {
case (i, s) => (s, i)
}
}

val result = ctx.run(swapped)

result.string mustEqual "SELECT x$1.string_field AS _1, x$1.int_field AS _2 FROM my_table x$1"
require(result.extractor(Row("St", 123), MirrorSession.default) == ("St", 123))
}

"arbitrary long tuple" in {
val result = ctx.run(myRow2Query)

result.extractor(Row(123, "St"), MirrorSession.default) mustEqual
(123, "St")
}

"get field of arbitrary long tuple" in {
inline def g = quote{
myRow2Query.map{
case h *: tail => h
}
}
val result = ctx.run(g)

result.extractor(Row(123, "St"), MirrorSession.default) mustEqual
(123)
}

"decode empty tuple" in {
inline def g = quote {
myRow2Query.map{
case (_, _) => EmptyTuple
}
}

val result = ctx.run(g)

result.extractor(Row(123, "St"), MirrorSession.default) mustEqual EmptyTuple
}

"construct tuple1" in {
inline def g = quote {
myRow1Query.map {
case (i, s) => i *: EmptyTuple
}
}

val result = ctx.run(g)

result.string mustEqual "SELECT x$1.int_field AS _1 FROM my_table x$1"
result.extractor(Row(123, "St"), MirrorSession.default) mustEqual
Tuple1(123)
}

"construct arbitrary tuple" in {
inline def g = quote {
myRow1Query.map {
case (i, s) => s *: i *: EmptyTuple
}
}
val result = ctx.run(g)

result.string mustEqual "SELECT x$1.string_field AS _1, x$1.int_field AS _2 FROM my_table x$1"
result.extractor(Row("St", 123), MirrorSession.default) mustEqual ("St", 123)

}

"constant arbitrary tuple" in {
inline def g = quote {
123 *: "St" *: true *: (3.14 *: EmptyTuple)
}
val result = ctx.run(g)
result.string mustEqual "SELECT 123 AS _1, 'St' AS _2, true AS _3, 3.14 AS _4"
result.info.executionType mustEqual Static
result.extractor(Row(123, "St", true, 3.14), MirrorSession.default) mustEqual (123, "St", true, 3.14)
}

"constant arbitrary tuple 1" in {
inline def g = quote {
(3.14 *: EmptyTuple)
}
val result = ctx.run(g)
result.string mustEqual "SELECT 3.14 AS _1"
result.info.executionType mustEqual Static
result.extractor(Row(3.14), MirrorSession.default) mustEqual Tuple1(3.14)
}
}

0 comments on commit 4a336ec

Please sign in to comment.