Skip to content

Commit

Permalink
inserted record can be returned from query
Browse files Browse the repository at this point in the history
  • Loading branch information
rolandjohann authored and deusaquilus committed Jul 4, 2019
1 parent 08cf5b3 commit 7219984
Show file tree
Hide file tree
Showing 83 changed files with 1,930 additions and 395 deletions.
Expand Up @@ -17,7 +17,7 @@ class MysqlAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConnect
def this(naming: N, config: Config) = this(naming, MysqlAsyncContextConfig(config))
def this(naming: N, configPrefix: String) = this(naming, LoadConfig(configPrefix))

override protected def extractActionResult[O](returningColumn: String, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
override protected def extractActionResult[O](returningAction: ReturnAction, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
result match {
case r: MySQLQueryResult =>
returningExtractor(new ArrayRowData(0, Map.empty, Array(r.lastInsertId)))
Expand Down
@@ -1,9 +1,10 @@
package io.getquill.context.async.mysql

import com.github.mauricio.async.db.QueryResult
import io.getquill.ReturnAction.ReturnColumns

import scala.concurrent.ExecutionContext.Implicits.global
import io.getquill.{ Literal, MysqlAsyncContext, Spec }
import io.getquill.{ Literal, MysqlAsyncContext, ReturnAction, Spec }

class MysqlAsyncContextSpec extends Spec {

Expand All @@ -18,7 +19,7 @@ class MysqlAsyncContextSpec extends Spec {

"Insert with returning with single column table" in {
val inserted: Long = await(testContext.run {
qr4.insert(lift(TestEntity4(0))).returning(_.i)
qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i)
})
await(testContext.run(qr4.filter(_.i == lift(inserted))))
.head.i mustBe inserted
Expand All @@ -35,13 +36,13 @@ class MysqlAsyncContextSpec extends Spec {
"cannot extract" in {
object ctx extends MysqlAsyncContext(Literal, "testMysqlDB") {
override def extractActionResult[O](
returningColumn: String,
returningAction: ReturnAction,
returningExtractor: ctx.Extractor[O]
)(result: QueryResult) =
super.extractActionResult(returningColumn, returningExtractor)(result)
super.extractActionResult(returningAction, returningExtractor)(result)
}
intercept[IllegalStateException] {
ctx.extractActionResult("w/e", row => 1)(new QueryResult(0, "w/e"))
ctx.extractActionResult(ReturnColumns(List("w/e")), row => 1)(new QueryResult(0, "w/e"))
}
ctx.close
}
Expand Down
Expand Up @@ -35,7 +35,7 @@ class ProductMysqlAsyncSpec extends ProductSpec {
val prd = Product(0L, "test1", 1L)
val inserted = await {
testContext.run {
product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id)
product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id)
}
}
val returnedProduct = await(testContext.run(productById(lift(inserted)))).head
Expand All @@ -47,7 +47,7 @@ class ProductMysqlAsyncSpec extends ProductSpec {
"Single insert with free variable and explicit quotation" in {
val prd = Product(0L, "test2", 2L)
val q1 = quote {
product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id)
product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id)
}
val inserted = await(testContext.run(q1))
val returnedProduct = await(testContext.run(productById(lift(inserted)))).head
Expand All @@ -60,7 +60,7 @@ class ProductMysqlAsyncSpec extends ProductSpec {
case class Product(id: Id, description: String, sku: Long)
val prd = Product(Id(0L), "test2", 2L)
val q1 = quote {
query[Product].insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id)
query[Product].insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id)
}
await(testContext.run(q1)) mustBe a[Id]
}
Expand Down
@@ -1,9 +1,10 @@
package io.getquill

import com.github.mauricio.async.db.{ QueryResult => DBQueryResult }
import com.github.mauricio.async.db.pool.PartitionedConnectionPool
import com.github.mauricio.async.db.postgresql.PostgreSQLConnection
import com.github.mauricio.async.db.{ QueryResult => DBQueryResult }
import com.typesafe.config.Config
import io.getquill.ReturnAction.{ ReturnColumns, ReturnNothing, ReturnRecord }
import io.getquill.context.async.{ ArrayDecoders, ArrayEncoders, AsyncContext, UUIDObjectEncoding }
import io.getquill.util.LoadConfig
import io.getquill.util.Messages.fail
Expand All @@ -18,7 +19,7 @@ class PostgresAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConn
def this(naming: N, config: Config) = this(naming, PostgresAsyncContextConfig(config))
def this(naming: N, configPrefix: String) = this(naming, LoadConfig(configPrefix))

override protected def extractActionResult[O](returningColumn: String, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
override protected def extractActionResult[O](returningAction: ReturnAction, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
result.rows match {
case Some(r) if r.nonEmpty =>
returningExtractor(r.head)
Expand All @@ -27,6 +28,14 @@ class PostgresAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConn
}
}

override protected def expandAction(sql: String, returningColumn: String): String =
s"$sql RETURNING $returningColumn"
override protected def expandAction(sql: String, returningAction: ReturnAction): String =
returningAction match {
// The Postgres dialect will create SQL that has a 'RETURNING' clause so we don't have to add one.
case ReturnRecord => s"$sql"
// The Postgres dialect will not actually use these below variants but in case we decide to plug
// in some other dialect into this context...
case ReturnColumns(columns) => s"$sql RETURNING ${columns.mkString(", ")}"
case ReturnNothing => s"$sql"
}

}
@@ -1,9 +1,10 @@
package io.getquill.context.async.postgres

import com.github.mauricio.async.db.QueryResult
import io.getquill.ReturnAction.ReturnColumns

import scala.concurrent.ExecutionContext.Implicits.global
import io.getquill.{ Literal, PostgresAsyncContext, Spec }
import io.getquill.{ Literal, PostgresAsyncContext, ReturnAction, Spec }

class PostgresAsyncContextSpec extends Spec {

Expand All @@ -18,11 +19,18 @@ class PostgresAsyncContextSpec extends Spec {

"Insert with returning with single column table" in {
val inserted: Long = await(testContext.run {
qr4.insert(lift(TestEntity4(0))).returning(_.i)
qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i)
})
await(testContext.run(qr4.filter(_.i == lift(inserted))))
.head.i mustBe inserted
}
"Insert with returning with multiple columns" in {
await(testContext.run(qr1.delete))
val inserted = await(testContext.run {
qr1.insert(lift(TestEntity("foo", 1, 18L, Some(123)))).returning(r => (r.i, r.s, r.o))
})
(1, "foo", Some(123)) mustBe inserted
}

"performIO" in {
await(performIO(runIO(qr4).transactional))
Expand All @@ -35,13 +43,13 @@ class PostgresAsyncContextSpec extends Spec {
"cannot extract" in {
object ctx extends PostgresAsyncContext(Literal, "testPostgresDB") {
override def extractActionResult[O](
returningColumn: String,
returningAction: ReturnAction,
returningExtractor: ctx.Extractor[O]
)(result: QueryResult) =
super.extractActionResult(returningColumn, returningExtractor)(result)
super.extractActionResult(returningAction, returningExtractor)(result)
}
intercept[IllegalStateException] {
ctx.extractActionResult("w/e", row => 1)(new QueryResult(0, "w/e"))
ctx.extractActionResult(ReturnColumns(List("w/e")), row => 1)(new QueryResult(0, "w/e"))
}
ctx.close
}
Expand Down
Expand Up @@ -4,14 +4,15 @@ import com.github.mauricio.async.db.Connection
import com.github.mauricio.async.db.{ QueryResult => DBQueryResult }
import com.github.mauricio.async.db.RowData
import com.github.mauricio.async.db.pool.PartitionedConnectionPool

import scala.concurrent.Await
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.duration.Duration
import scala.util.Try
import io.getquill.context.sql.SqlContext
import io.getquill.context.sql.idiom.SqlIdiom
import io.getquill.NamingStrategy
import io.getquill.{ NamingStrategy, ReturnAction }
import io.getquill.util.ContextLogger
import io.getquill.monad.ScalaFutureIOMonad
import io.getquill.context.{ Context, TranslateContext }
Expand Down Expand Up @@ -48,9 +49,9 @@ abstract class AsyncContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection]
case other => f(pool)
}

protected def extractActionResult[O](returningColumn: String, extractor: Extractor[O])(result: DBQueryResult): O
protected def extractActionResult[O](returningAction: ReturnAction, extractor: Extractor[O])(result: DBQueryResult): O

protected def expandAction(sql: String, returningColumn: String) = sql
protected def expandAction(sql: String, returningAction: ReturnAction) = sql

def probe(sql: String) =
Try {
Expand Down Expand Up @@ -88,12 +89,12 @@ abstract class AsyncContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection]
withConnection(_.sendPreparedStatement(sql, values)).map(_.rowsAffected)
}

def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningColumn: String)(implicit ec: ExecutionContext): Future[T] = {
val expanded = expandAction(sql, returningColumn)
def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningAction: ReturnAction)(implicit ec: ExecutionContext): Future[T] = {
val expanded = expandAction(sql, returningAction)
val (params, values) = prepare(Nil)
logger.logQuery(sql, params)
withConnection(_.sendPreparedStatement(expanded, values))
.map(extractActionResult(returningColumn, extractor))
.map(extractActionResult(returningAction, extractor))
}

def executeBatchAction(groups: List[BatchGroup])(implicit ec: ExecutionContext): Future[List[Long]] =
Expand Down
Expand Up @@ -2,6 +2,7 @@ package io.getquill.context.cassandra

import io.getquill.ast.{ TraversableOperation, _ }
import io.getquill.NamingStrategy
import io.getquill.context.CannotReturn
import io.getquill.util.Messages.fail
import io.getquill.idiom.Idiom
import io.getquill.idiom.StatementInterpolator._
Expand All @@ -10,7 +11,7 @@ import io.getquill.idiom.SetContainsToken
import io.getquill.idiom.Token
import io.getquill.util.Interleave

object CqlIdiom extends CqlIdiom
object CqlIdiom extends CqlIdiom with CannotReturn

trait CqlIdiom extends Idiom {

Expand All @@ -33,6 +34,7 @@ trait CqlIdiom extends Idiom {
case a: Operation => a.token
case a: Action => a.token
case a: Ident => a.token
case a: ExternalIdent => a.token
case a: Property => a.token
case a: Value => a.token
case a: Function => a.body.token
Expand Down Expand Up @@ -135,6 +137,10 @@ trait CqlIdiom extends Idiom {
case e => strategy.default(e.name).token
}

implicit def externalIdentTokenizer(implicit strategy: NamingStrategy): Tokenizer[ExternalIdent] = Tokenizer[ExternalIdent] {
case e => strategy.default(e.name).token
}

implicit def assignmentTokenizer(implicit propertyTokenizer: Tokenizer[Property], strategy: NamingStrategy): Tokenizer[Assignment] = Tokenizer[Assignment] {
case Assignment(alias, prop, value) =>
stmt"${prop.token} = ${value.token}"
Expand Down Expand Up @@ -175,6 +181,9 @@ trait CqlIdiom extends Idiom {
case _: Returning =>
fail(s"Cql doesn't support returning generated during insertion")

case _: ReturningGenerated =>
fail(s"Cql doesn't support returning generated during insertion")

case other =>
fail(s"Action ast can't be translated to cql: '$other'")
}
Expand Down
Expand Up @@ -38,10 +38,7 @@ class CqlIdiomSpec extends Spec {
"SELECT s FROM TestEntity WHERE i = 1 ORDER BY s ASC LIMIT 1"
}
"returning" in {
val q = quote {
query[TestEntity].insert(_.l -> 1L).returning(_.i)
}
"mirrorContext.run(q).string" mustNot compile
"mirrorContext.run(query[TestEntity].insert(_.l -> 1L).returning(_.i)).string" mustNot compile
}
}

Expand Down
12 changes: 6 additions & 6 deletions quill-core/src/main/scala/io/getquill/AsyncMirrorContext.scala
Expand Up @@ -56,11 +56,11 @@ class AsyncMirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom

case class ActionMirror(string: String, prepareRow: PrepareRow)(implicit val ec: ExecutionContext)

case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningColumn: String)(implicit val ec: ExecutionContext)
case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningBehavior: ReturnAction)(implicit val ec: ExecutionContext)

case class BatchActionMirror(groups: List[(String, List[Row])])(implicit val ec: ExecutionContext)

case class BatchActionReturningMirror[T](groups: List[(String, String, List[PrepareRow])], extractor: Extractor[T])(implicit val ec: ExecutionContext)
case class BatchActionReturningMirror[T](groups: List[(String, ReturnAction, List[PrepareRow])], extractor: Extractor[T])(implicit val ec: ExecutionContext)

case class QueryMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T])(implicit val ec: ExecutionContext)

Expand All @@ -74,8 +74,8 @@ class AsyncMirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom
Future(ActionMirror(string, prepare(Row())._2))

def executeActionReturning[O](string: String, prepare: Prepare = identityPrepare, extractor: Extractor[O],
returningColumn: String)(implicit ec: ExecutionContext) =
Future(ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningColumn))
returningBehavior: ReturnAction)(implicit ec: ExecutionContext) =
Future(ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningBehavior))

def executeBatchAction(groups: List[BatchGroup])(implicit ec: ExecutionContext) =
Future {
Expand All @@ -91,8 +91,8 @@ class AsyncMirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom
Future {
BatchActionReturningMirror[T](
groups.map {
case BatchGroupReturning(string, column, prepare) =>
(string, column, prepare.map(_(Row())._2))
case BatchGroupReturning(string, returningBehavior, prepare) =>
(string, returningBehavior, prepare.map(_(Row())._2))
}, extractor
)
}
Expand Down
12 changes: 6 additions & 6 deletions quill-core/src/main/scala/io/getquill/MirrorContext.scala
Expand Up @@ -41,11 +41,11 @@ class MirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom: Idi

case class ActionMirror(string: String, prepareRow: PrepareRow)

case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningColumn: String)
case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningBehavior: ReturnAction)

case class BatchActionMirror(groups: List[(String, List[Row])])

case class BatchActionReturningMirror[T](groups: List[(String, String, List[PrepareRow])], extractor: Extractor[T])
case class BatchActionReturningMirror[T](groups: List[(String, ReturnAction, List[PrepareRow])], extractor: Extractor[T])

case class QueryMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T])

Expand All @@ -59,8 +59,8 @@ class MirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom: Idi
ActionMirror(string, prepare(Row())._2)

def executeActionReturning[O](string: String, prepare: Prepare = identityPrepare, extractor: Extractor[O],
returningColumn: String) =
ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningColumn)
returningBehavior: ReturnAction) =
ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningBehavior)

def executeBatchAction(groups: List[BatchGroup]) =
BatchActionMirror {
Expand All @@ -73,8 +73,8 @@ class MirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom: Idi
def executeBatchActionReturning[T](groups: List[BatchGroupReturning], extractor: Extractor[T]) =
BatchActionReturningMirror[T](
groups.map {
case BatchGroupReturning(string, column, prepare) =>
(string, column, prepare.map(_(Row())._2))
case BatchGroupReturning(string, returningBehavior, prepare) =>
(string, returningBehavior, prepare.map(_(Row())._2))
}, extractor
)

Expand Down

0 comments on commit 7219984

Please sign in to comment.