diff --git a/quill-sql/src/main/scala/io/getquill/generic/GenericDecoder.scala b/quill-sql/src/main/scala/io/getquill/generic/GenericDecoder.scala index 39370868..b73e0f5d 100644 --- a/quill-sql/src/main/scala/io/getquill/generic/GenericDecoder.scala +++ b/quill-sql/src/main/scala/io/getquill/generic/GenericDecoder.scala @@ -125,6 +125,31 @@ object GenericDecoder { } } // end flatten + // similar to flatten but without labels + @tailrec + def values[ResultRow: Type, Session: Type, Types: Type]( + index: Int, + baseIndex: Expr[Int], + resultRow: Expr[ResultRow], + session: Expr[Session] + )(accum: List[FlattenData] = List())(using Quotes): List[FlattenData] = { + import quotes.reflect.{Term => QTerm, _} + + Type.of[Types] match { + case '[tpe *: types] if Expr.summon[GenericDecoder[ResultRow, Session, tpe, DecodingType.Specific]].isEmpty => + val result = decode[tpe, ResultRow, Session](index, baseIndex, resultRow, session) + val nextIndex = result.index + 1 + values[ResultRow, Session, types](nextIndex, baseIndex, resultRow, session)(result +: accum) + case '[tpe *: types] => + val result = decode[tpe, ResultRow, Session](index, baseIndex, resultRow, session, None) + val nextIndex = index + 1 + values[ResultRow, Session, types](nextIndex, baseIndex, resultRow, session)(result +: accum) + case '[EmptyTuple] => accum + + case typesTup => report.throwError("Cannot Derive Product during Values extraction:\n" + typesTup) + } + } // end values + def decodeOptional[T: Type, ResultRow: Type, Session: Type](index: Int, baseIndex: Expr[Int], resultRow: Expr[ResultRow], session: Expr[Session])(using Quotes): FlattenData = { import quotes.reflect._ // Try to summon a specific optional from the context, this may not exist since @@ -163,9 +188,9 @@ object GenericDecoder { // List((new Name(Decoder("Joe") || Decoder("Bloggs")), Decoder(123)) // This is what needs to be fed into the constructor of the outer-entity i.e. // new Person((new Name(Decoder("Joe") || Decoder("Bloggs")), Decoder(123)) - val productElments = flattenData.map(_.decodedExpr) + val productElements = flattenData.map(_.decodedExpr) // actually doing the construction i.e. `new Person(...)` - val constructed = ConstructDecoded[T](types, productElments, m) + val constructed = ConstructDecoded[T](types, productElements, m) // E.g. for Person("Joe", 123) the List(q"!nullChecker(0,row)", q"!nullChecker(1,row)") columns // that eventually turn into List(!NullChecker("Joe"), !NullChecker(123)) columns. @@ -192,6 +217,11 @@ object GenericDecoder { TypeRepr.of[T] <:< TypeRepr.of[Option[Any]] } + private def isTuple[T: Type](using Quotes) = { + import quotes.reflect._ + TypeRepr.of[T] <:< TypeRepr.of[Tuple] + } + private def isBuiltInType[T: Type](using Quotes) = { import quotes.reflect._ isOption[T] || (TypeRepr.of[T] <:< TypeRepr.of[Seq[_]]) @@ -207,6 +237,16 @@ object GenericDecoder { case '[Option[tpe]] => decodeOptional[tpe, ResultRow, Session](index, baseIndex, resultRow, session) } + } else if (isTuple[T]) { + if (TypeRepr.of[T] <:< TypeRepr.of[EmptyTuple]) { + FlattenData(Type.of[T], '{ EmptyTuple }, '{ false }, index) + } else { + val flattenData = values[ResultRow, Session, T](index, baseIndex, resultRow, session)().reverse + val elementTerms = flattenData.map(_.decodedExpr) // expressions that represent values for tuple elements + val constructed = '{ scala.runtime.Tuples.fromArray(${ Varargs(elementTerms) }.toArray[Any](Predef.summon[ClassTag[Any]]).asInstanceOf[Array[Object]]).asInstanceOf[T] } + val nullChecks = flattenData.map(_._3).reduce((a, b) => '{ $a || $b }) + FlattenData(Type.of[T], constructed, nullChecks, flattenData.last.index) + } } else { // specifically if there is a decoder found, allow optional override of the index via a resolver val decoderIndex = overriddenIndex.getOrElse(elementIndex) @@ -341,21 +381,10 @@ object ConstructDecoded { val tpe = TypeRepr.of[T] val constructor = TypeRepr.of[T].typeSymbol.primaryConstructor // If we are a tuple, we can easily construct it - if (tpe <:< TypeRepr.of[Tuple]) { - val construct = - Apply( - TypeApply( - Select(New(TypeTree.of[T]), constructor), - types.map { tpe => - tpe match { - case '[tt] => TypeTree.of[tt] - } - } - ), - terms.map(_.asTerm) - ) - // println(s"=========== Create from Tuple Constructor ${Format.Expr(construct.asExprOf[T])} ===========") - construct.asExprOf[T] + if (tpe <:< TypeRepr.of[EmptyTuple]) { + '{EmptyTuple} + } else if (tpe <:< TypeRepr.of[Tuple]) { + '{scala.runtime.Tuples.fromIArray(IArray(${Varargs(terms)})).asInstanceOf[T]} // If we are a case class with no generic parameters, we can easily construct it } else if (tpe.classSymbol.exists(_.flags.is(Flags.Case)) && !constructor.paramSymss.exists(_.exists(_.isTypeParam))) { val construct = diff --git a/quill-sql/src/main/scala/io/getquill/parser/Parser.scala b/quill-sql/src/main/scala/io/getquill/parser/Parser.scala index f120fe48..588c2fd1 100644 --- a/quill-sql/src/main/scala/io/getquill/parser/Parser.scala +++ b/quill-sql/src/main/scala/io/getquill/parser/Parser.scala @@ -53,6 +53,7 @@ trait ParserLibrary extends ParserFactory { protected def functionParser(using Quotes, TranspileConfig) = ParserChain.attempt(FunctionParser(_)) protected def functionApplyParser(using Quotes, TranspileConfig) = ParserChain.attempt(FunctionApplyParser(_)) protected def valParser(using Quotes, TranspileConfig) = ParserChain.attempt(ValParser(_)) + protected def arbitraryTupleParser(using Quotes, TranspileConfig) = ParserChain.attempt(ArbitraryTupleBlockParser(_)) protected def blockParser(using Quotes, TranspileConfig) = ParserChain.attempt(BlockParser(_)) protected def extrasParser(using Quotes, TranspileConfig) = ParserChain.attempt(ExtrasParser(_)) protected def operationsParser(using Quotes, TranspileConfig) = ParserChain.attempt(OperationsParser(_)) @@ -88,6 +89,7 @@ trait ParserLibrary extends ParserFactory { .orElse(functionParser) // decided to have it be it's own parser unlike Quill3 .orElse(patMatchParser) .orElse(valParser) + .orElse(arbitraryTupleParser) .orElse(blockParser) .orElse(operationsParser) .orElse(extrasParser) @@ -153,6 +155,112 @@ class ValParser(val rootParse: Parser)(using Quotes, TranspileConfig) case Unseal(ValDefTerm(ast)) => ast } } +/** + * Matches `runtime.Tuples.cons(head,tail)`. + */ +object TupleCons { + def unapply(using Quotes)(t: quotes.reflect.Term): Option[(quotes.reflect.Term, quotes.reflect.Term)] = { + import quotes.reflect.* + t match { + case Apply(Select(Select(Ident("runtime"), "Tuples"), "cons"), List(head, tail)) => + Some((head, tail)) + case _ => + None + } + } +} + +/** + * Matches inner.asInstanceOf[T]: T + */ +object AsInstanceOf { + def unapply(using Quotes)(term: quotes.reflect.Term): Option[quotes.reflect.Term] = { + import quotes.reflect._ + term match { + case TypeApply(Select(inner, "asInstanceOf"), _) => Some(inner) + case _ => None + } + } +} + +/** + * Matches an inlined call to `Tuple.*:`: + * {{{ + * { + * val Tuple_this: scala.Tuple$package.EmptyTuple.type = scala.Tuple$package.EmptyTuple + * + * (scala.runtime.Tuples.cons(i, Tuple_this).asInstanceOf[scala.*:[scala.Int, scala.Tuple$package.EmptyTuple.type]]: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple]) + * } + * }}} + */ +object ArbitraryTupleConstructionInlined { + def unapply(using Quotes)(t: quotes.reflect.Term): Option[(quotes.reflect.Term, quotes.reflect.Term)] = { + import quotes.reflect.{Ident => TIdent, *} + t match { + case + Inlined( + _, + List(ValDef("Tuple_this", _, Some(prevTuple))), + Typed(AsInstanceOf(TupleCons(head, TIdent("Tuple_this"))), _) + ) => + Some((head, prevTuple)) + case _ => None + } + } +} + +/** + * Parses a few cases of arbitrary tuples. + * + * Scala 3 produces a few different trees for arbitrary tuples. Method `*:` is marked as inline. + * Under the hood it actually invokes `Tuples.cons` function: + * {{{ + * inline def *: [H, This >: this.type <: Tuple] (x: H): H *: This = + * runtime.Tuples.cons(x, this).asInstanceOf[H *: This] + * }}} + * So, at least we have to match Tuples.cons. + * However, it's not the only variation. Scala also produces a block with intermediate val `Tuple_this` definitions: + * {{{ + * { + * val Tuple_this: scala.Tuple$package.EmptyTuple.type = scala.Tuple$package.EmptyTuple + * val `Tuple_thisâ‚‚`: scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple] = (scala.runtime.Tuples.cons("", Tuple_this).asInstanceOf[scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple.type]]: scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple]) + * + * (scala.runtime.Tuples.cons(1, `Tuple_thisâ‚‚`).asInstanceOf[scala.*:[scala.Int, scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple]]]: scala.*:[scala.Int, scala.*:[java.lang.String, scala.Tuple$package.EmptyTuple]]) + * } + * }}} + */ +class ArbitraryTupleBlockParser(val rootParse: Parser)(using Quotes, TranspileConfig) + extends Parser(rootParse) + with PatternMatchingValues { + + import quotes.reflect.{Block => TBlock, Ident => TIdent, _} + + def attempt = { + case '{EmptyTuple} => + ast.Tuple(List()) + case '{$a *: EmptyTuple} => + val aAst = rootParse(a) + ast.Tuple(List(aAst)) + case inlined@Unseal(ArbitraryTupleConstructionInlined(singleValue, prevTuple)) => + val headAst = rootParse(singleValue.asExpr) + val prevTupleAst = rootParse(prevTuple.asExpr) + prevTupleAst match { + case ast.Tuple(lst) => ast.Tuple(headAst :: lst) + case _ => + throw IllegalArgumentException(s"Unexpected tuple ast ${prevTupleAst}") + } + case block@Unseal(TBlock(parts, ArbitraryTupleConstructionInlined(head,Typed(AsInstanceOf(TupleCons(head2,TIdent("Tuple_this"))), _)) )) if (parts.length > 0) => + val headAst = rootParse(head.asExpr) + val head2Ast = rootParse(head2.asExpr) + val partsAsts = headAst :: head2Ast :: parts.reverse.flatMap{ + case ValDef("Tuple_this", tpe, Some(TIdent("EmptyTuple"))) => List() + case ValDef("Tuple_this", tpe, Some(Typed(AsInstanceOf(TupleCons(next,TIdent("Tuple_this"))), _))) => List(rootParse(next.asExpr)) + case ValDef("Tuple_this", tpe, Some(unknown)) => + throw IllegalArgumentException(s"Unexpected Tuple_this = ${unknown.show}") + } + Tuple(partsAsts) + } +} class BlockParser(val rootParse: Parser)(using Quotes, TranspileConfig) extends Parser(rootParse) diff --git a/quill-sql/src/main/scala/io/getquill/parser/ParserHelpers.scala b/quill-sql/src/main/scala/io/getquill/parser/ParserHelpers.scala index b9a69d38..e0c912cd 100644 --- a/quill-sql/src/main/scala/io/getquill/parser/ParserHelpers.scala +++ b/quill-sql/src/main/scala/io/getquill/parser/ParserHelpers.scala @@ -490,6 +490,12 @@ object ParserHelpers { binds.zipWithIndex.flatMap { case (bind, idx) => tupleBindsPath(bind, path :+ s"_${idx + 1}") } + case Unapply(TypeApply(Select(TIdent("*:"), "unapply"), types), implicits, List(h, t)) => + List( + tupleBindsPath(h, path :+ s"head"), + tupleBindsPath(h, path :+ s"tail") + ) + .flatten // If it's a "case _ => ..." then that just translates into the body expression so we don't // need a clause to beta reduction over the entire partial-function case TIdent("_") => diff --git a/quill-sql/src/main/scala/io/getquill/quat/QuatMaking.scala b/quill-sql/src/main/scala/io/getquill/quat/QuatMaking.scala index 0a86a11c..cb5a35d2 100644 --- a/quill-sql/src/main/scala/io/getquill/quat/QuatMaking.scala +++ b/quill-sql/src/main/scala/io/getquill/quat/QuatMaking.scala @@ -190,6 +190,25 @@ trait QuatMakingBase { None } + object ArbitraryArityTupleType { + def unapply(using Quotes)(tpe: quotes.reflect.TypeRepr): Option[List[quotes.reflect.TypeRepr]] = + if (tpe.is[Tuple]) + Some(tupleParts(tpe)) + else + None + + @tailrec + def tupleParts(using Quotes)(tpe: quotes.reflect.TypeRepr, accum: List[quotes.reflect.TypeRepr] = Nil): List[quotes.reflect.TypeRepr] = + tpe.asType match { + case '[h *: t] => + val htpe = quotes.reflect.TypeRepr.of[h] + val ttpe = quotes.reflect.TypeRepr.of[t] + tupleParts(ttpe, htpe :: accum) + case '[EmptyTuple] => + accum.reverse + } + } + object OptionType { def unapply(using Quotes)(tpe: quotes.reflect.TypeRepr): Option[quotes.reflect.TypeRepr] = { import quotes.reflect._ @@ -384,6 +403,9 @@ trait QuatMakingBase { case CaseClassBaseType(name, fields) if !existsEncoderFor(tpe) || tpe <:< TypeRepr.of[Udt] => Quat.Product(name, fields.map { case (fieldName, fieldType) => (fieldName, parseType(fieldType)) }) + case ArbitraryArityTupleType(tupleParts) => + Quat.Product("Tuple", tupleParts.zipWithIndex.map { case (fieldType, idx) => (s"_${idx + 1}", parseType(fieldType)) }) + // If we are already inside a bounded type, treat an arbitrary type as a interface list case ArbitraryBaseType(name, fields) if (boundedInterfaceType) => Quat.Product(name, fields.map { case (fieldName, fieldType) => (fieldName, parseType(fieldType)) }) diff --git a/quill-sql/src/test/scala/io/getquill/ArbitraryTupleSpec.scala b/quill-sql/src/test/scala/io/getquill/ArbitraryTupleSpec.scala new file mode 100644 index 00000000..bc0ac871 --- /dev/null +++ b/quill-sql/src/test/scala/io/getquill/ArbitraryTupleSpec.scala @@ -0,0 +1,122 @@ +package io.getquill + +import io.getquill.context.ExecutionType.Static +import io.getquill.context.mirror.{MirrorSession, Row} +import io.getquill.generic.TupleMember + +class ArbitraryTupleSpec extends Spec { + + val ctx = new MirrorContext(PostgresDialect, Literal) + import ctx._ + + type MyRow1 = (Int, String) + type MyRow2 = Int *: String *: EmptyTuple + + inline def myRow1Query = quote { + querySchema[MyRow1]("my_table", t => t._1 -> "int_field", t => t._2 -> "string_field") + } + + inline def myRow2Query = quote { + querySchema[MyRow2]("my_table", t => t._1 -> "int_field", t => t._2 -> "string_field") + } + + "ordinary tuple" in { + val result = ctx.run(myRow1Query) + + result.string mustEqual "SELECT x.int_field, x.string_field FROM my_table x" + result.extractor(Row(123, "St"), MirrorSession.default) mustEqual + (123, "St") + } + + "ordinary tuple swap" in { + + transparent inline def swapped: Quoted[EntityQuery[(String, Int)]] = quote { + myRow1Query.map { + case (i, s) => (s, i) + } + } + + val result = ctx.run(swapped) + + result.string mustEqual "SELECT x$1.string_field AS _1, x$1.int_field AS _2 FROM my_table x$1" + require(result.extractor(Row("St", 123), MirrorSession.default) == ("St", 123)) + } + + "arbitrary long tuple" in { + val result = ctx.run(myRow2Query) + + result.extractor(Row(123, "St"), MirrorSession.default) mustEqual + (123, "St") + } + + "get field of arbitrary long tuple" in { + inline def g = quote{ + myRow2Query.map{ + case h *: tail => h + } + } + val result = ctx.run(g) + + result.extractor(Row(123, "St"), MirrorSession.default) mustEqual + (123) + } + + "decode empty tuple" in { + inline def g = quote { + myRow2Query.map{ + case (_, _) => EmptyTuple + } + } + + val result = ctx.run(g) + + result.extractor(Row(123, "St"), MirrorSession.default) mustEqual EmptyTuple + } + + "construct tuple1" in { + inline def g = quote { + myRow1Query.map { + case (i, s) => i *: EmptyTuple + } + } + + val result = ctx.run(g) + + result.string mustEqual "SELECT x$1.int_field AS _1 FROM my_table x$1" + result.extractor(Row(123, "St"), MirrorSession.default) mustEqual + Tuple1(123) + } + + "construct arbitrary tuple" in { + inline def g = quote { + myRow1Query.map { + case (i, s) => s *: i *: EmptyTuple + } + } + val result = ctx.run(g) + + result.string mustEqual "SELECT x$1.string_field AS _1, x$1.int_field AS _2 FROM my_table x$1" + result.extractor(Row("St", 123), MirrorSession.default) mustEqual ("St", 123) + + } + + "constant arbitrary tuple" in { + inline def g = quote { + 123 *: "St" *: true *: (3.14 *: EmptyTuple) + } + val result = ctx.run(g) + result.string mustEqual "SELECT 123 AS _1, 'St' AS _2, true AS _3, 3.14 AS _4" + result.info.executionType mustEqual Static + result.extractor(Row(123, "St", true, 3.14), MirrorSession.default) mustEqual (123, "St", true, 3.14) + } + + "constant arbitrary tuple 1" in { + inline def g = quote { + (3.14 *: EmptyTuple) + } + val result = ctx.run(g) + result.string mustEqual "SELECT 3.14 AS _1" + result.info.executionType mustEqual Static + result.extractor(Row(3.14), MirrorSession.default) mustEqual Tuple1(3.14) + } +}