From 060204e205ddbfdcd5d26bdc856b39839a42b14a Mon Sep 17 00:00:00 2001 From: Alexander Ioffe Date: Thu, 1 Dec 2022 18:32:40 -0500 Subject: [PATCH] Using zio-direct in Quill ZioJdbcContext --- build.sbt | 5 +- .../io/getquill/CassandraZioContext.scala | 19 +++++ .../context/qzio/ZioJdbcContext.scala | 81 ++++++++++--------- .../qzio/ZioJdbcUnderlyingContext.scala | 24 +++--- .../io/getquill/examples/GenericDao.scala | 5 ++ 5 files changed, 85 insertions(+), 49 deletions(-) diff --git a/build.sbt b/build.sbt index 584a506ee..73eb87129 100644 --- a/build.sbt +++ b/build.sbt @@ -240,7 +240,8 @@ lazy val `quill-zio` = Test / fork := true, libraryDependencies ++= Seq( "dev.zio" %% "zio" % "2.0.2", - "dev.zio" %% "zio-streams" % "2.0.2" + "dev.zio" %% "zio-streams" % "2.0.2", + "dev.zio" %% "zio-direct" % "1.0.0-RC1" ) ) .dependsOn(`quill-sql` % "compile->compile;test->test") @@ -345,7 +346,7 @@ lazy val basicSettings = Seq( ExclusionRule("org.scala-lang.modules", "scala-collection-compat_2.13") ), scalaVersion := { - if (isCommunityBuild) dottyLatestNightlyBuild().get else "3.1.3" + if (isCommunityBuild) dottyLatestNightlyBuild().get else "3.2.0" }, organization := "io.getquill", // The -e option is the 'error' report of ScalaTest. We want it to only make a log diff --git a/quill-cassandra-zio/src/main/scala/io/getquill/CassandraZioContext.scala b/quill-cassandra-zio/src/main/scala/io/getquill/CassandraZioContext.scala index 4265327c9..b1970e94f 100644 --- a/quill-cassandra-zio/src/main/scala/io/getquill/CassandraZioContext.scala +++ b/quill-cassandra-zio/src/main/scala/io/getquill/CassandraZioContext.scala @@ -95,6 +95,16 @@ class CassandraZioContext[+N <: NamingStrategy](val naming: N) } private[getquill] def execute(cql: String, prepare: Prepare, csession: CassandraZioSession, fetchSize: Option[Int]) = + /* + val p = prepareRowAndLog(cql, prepare).run + attempt { + fetchSize match { + case Some(value) => p.setPageSize + case None => p + } + } + ZIO.fromCompletionStage(csession.session.executeAsync(p)).await + */ simpleBlocking { prepareRowAndLog(cql, prepare) .mapAttempt { p => @@ -153,6 +163,15 @@ class CassandraZioContext[+N <: NamingStrategy](val naming: N) rows <- ZIO.attempt(rs.currentPage()) singleRow <- ZIO.attempt(handleSingleResult(cql, rows.asScala.map(row => extractor(row, csession)).toList)) } yield singleRow + + /* + val csession = ZIO.service[CassandraZioSession].run + val rs = execute(cql, prepare, csession, Some(1)).run + unsafe { + rows = rs.currentPage() + singleRow = handleSingleResult(cql, rows.asScala.map(row => extractor(row, csession)).toList) + } + */ } def executeAction(cql: String, prepare: Prepare = identityPrepare)(info: ExecutionInfo, dc: Runner): CIO[Unit] = simpleBlocking { diff --git a/quill-jdbc-zio/src/main/scala/io/getquill/context/qzio/ZioJdbcContext.scala b/quill-jdbc-zio/src/main/scala/io/getquill/context/qzio/ZioJdbcContext.scala index 5a33c2f23..20eaab2d9 100644 --- a/quill-jdbc-zio/src/main/scala/io/getquill/context/qzio/ZioJdbcContext.scala +++ b/quill-jdbc-zio/src/main/scala/io/getquill/context/qzio/ZioJdbcContext.scala @@ -16,6 +16,10 @@ import io.getquill.* import io.getquill.jdbczio.Quill import zio.ZIO.attemptBlocking import zio.ZIO.blocking +import zio.direct._ +import zio.direct.core.metaprog.Verify +import zio.direct.Dsl.Params +import zio.Scope /** * Quill context that executes JDBC queries inside of ZIO. Unlike most other contexts @@ -180,47 +184,52 @@ abstract class ZioJdbcContext[+Dialect <: SqlIdiom, +Naming <: NamingStrategy] e * */ def transaction[R <: DataSource, A](op: ZIO[R, Throwable, A]): ZIO[R, Throwable, A] = { - blocking(currentConnection.get.flatMap { - // We can just return the op in the case that there is already a connection set on the fiber ref - // because the op is execute___ which will lookup the connection from the fiber ref via onConnection/onConnectionStream - // This will typically happen for nested transactions e.g. transaction(transaction(a *> b) *> c) - case Some(connection) => op - case None => - val connection = for { - env <- ZIO.service[DataSource] - connection <- scopedBestEffort(attemptBlocking(env.getConnection)) - // Get the current value of auto-commit - prevAutoCommit <- attemptBlocking(connection.getAutoCommit) - // Disable auto-commit since we need to be able to roll back. Once everything is done, set it - // to whatever the previous value was. - _ <- ZIO.acquireRelease(attemptBlocking(connection.setAutoCommit(false))) { _ => - attemptBlocking(connection.setAutoCommit(prevAutoCommit)).orDie - } - _ <- ZIO.acquireRelease(currentConnection.set(Some(connection))) { _ => - // Note. We are failing the fiber if auto-commit reset fails. For some circumstances this may be too aggresive. - // If the connection pool e.g. Hikari resets this property for a recycled connection anyway doing it here - // might not be necessary - currentConnection.set(None) - } - // Once the `use` of this outer-Scoped is done, rollback the connection if needed - _ <- ZIO.addFinalizerExit { - case Success(_) => blocking(ZIO.succeed(connection.commit())) - case Failure(cause) => blocking(ZIO.succeed(connection.rollback())) - } - } yield () - - ZIO.scoped(connection *> op) - }) + defer { + currentConnection.get.run match { + case Some(conn) => op.run + case None => + ZIO.scoped(defer { + defer { + val env = ZIO.service[DataSource].run + val connection = scopedBestEffort(attemptBlocking(env.getConnection)).run + // Get the current value of auto-commit + val prevAutoCommit = attemptBlocking(connection.getAutoCommit).run + // Disable auto-commit since we need to be able to roll back. Once everything is done, set it + // to whatever the previous value was. + ZIO.acquireRelease(attemptBlocking(connection.setAutoCommit(false))) { _ => + attemptBlocking(connection.setAutoCommit(prevAutoCommit)).orDie + }.run + ZIO.acquireRelease(currentConnection.set(Some(connection))) { _ => + // Note. We are failing the fiber if auto-commit reset fails. For some circumstances this may be too aggresive. + // If the connection pool e.g. Hikari resets this property for a recycled connection anyway doing it here + // might not be necessary + currentConnection.set(None) + }.run + ZIO.addFinalizerExit { + case Success(_) => blocking(ZIO.succeed(connection.commit())) + case Failure(cause) => blocking(ZIO.succeed(connection.rollback())) + }.run + }.run + op.run + }).run + } + } } private def onConnection[T](qlio: ZIO[Connection, SQLException, T]): ZIO[DataSource, SQLException, T] = - currentConnection.get.flatMap { - case Some(connection) => - blocking(qlio.provideEnvironment(ZEnvironment(connection))) - case None => - blocking(qlio.provideLayer(Quill.Connection.acquireScoped)) + defer { + currentConnection.get.run match { + case Some(connection) => + blocking(qlio.provideEnvironment(ZEnvironment(connection))).run + case None => + blocking(qlio.provideLayer(Quill.Connection.acquireScoped)).run + } } + def foo(): Unit = { + val iter: ZIO[Scope, Nothing, Iterator[Either[Nothing, Option[Connection]]]] = ZStream.fromZIO(currentConnection.get).toIterator + } + private def onConnectionStream[T](qstream: ZStream[Connection, SQLException, T]): ZStream[DataSource, SQLException, T] = streamBlocker *> ZStream.fromZIO(currentConnection.get).flatMap { case Some(connection) => diff --git a/quill-jdbc-zio/src/main/scala/io/getquill/context/qzio/ZioJdbcUnderlyingContext.scala b/quill-jdbc-zio/src/main/scala/io/getquill/context/qzio/ZioJdbcUnderlyingContext.scala index 4d6c34789..94fa5ad1c 100644 --- a/quill-jdbc-zio/src/main/scala/io/getquill/context/qzio/ZioJdbcUnderlyingContext.scala +++ b/quill-jdbc-zio/src/main/scala/io/getquill/context/qzio/ZioJdbcUnderlyingContext.scala @@ -16,6 +16,7 @@ import javax.sql.DataSource import scala.reflect.ClassTag import scala.util.Try import scala.annotation.targetName +import zio.direct._ abstract class ZioJdbcUnderlyingContext[+Dialect <: SqlIdiom, +Naming <: NamingStrategy] extends ZioContext[Dialect, Naming] with JdbcContextVerbExecute[Dialect, Naming] @@ -106,13 +107,13 @@ abstract class ZioJdbcUnderlyingContext[+Dialect <: SqlIdiom, +Naming <: NamingS * they can be generalized to Something <: Connection. E.g. `Connection with OtherStuff` generalizes to `Something <: Connection`. */ private[getquill] def withoutAutoCommit[R <: Connection, A, E <: Throwable: ClassTag](f: ZIO[R, E, A]): ZIO[R, E, A] = { - for { - conn <- ZIO.service[Connection] - autoCommitPrev = conn.getAutoCommit - r <- ZIO.acquireReleaseWith(sqlEffect(conn))(conn => ZIO.succeed(conn.setAutoCommit(autoCommitPrev))) { conn => + defer { + val conn = ZIO.service[Connection].run + val autoCommitPrev = conn.getAutoCommit + ZIO.acquireReleaseWith(sqlEffect(conn))(conn => ZIO.succeed(conn.setAutoCommit(autoCommitPrev))) { conn => sqlEffect(conn.setAutoCommit(false)).flatMap(_ => f) - }.refineToOrDie[E] - } yield r + }.refineToOrDie[E].run + } } private[getquill] def streamWithoutAutoCommit[A](f: ZStream[Connection, Throwable, A]): ZStream[Connection, Throwable, A] = { @@ -179,11 +180,12 @@ abstract class ZioJdbcUnderlyingContext[+Dialect <: SqlIdiom, +Naming <: NamingS val scopedEnv: ZStream[Connection, Throwable, (Connection, PrepareRow, ResultSet)] = ZStream.scoped { - for { - conn <- ZIO.service[Connection] - ps <- scopedBestEffort(ZIO.attempt(prepareStatement(conn))) - rs <- scopedBestEffort(ZIO.attempt(ps.executeQuery())) - } yield (conn, ps, rs) + defer { + val conn = ZIO.service[Connection].run + val ps = scopedBestEffort(ZIO.attempt(prepareStatement(conn))).run + val rs = scopedBestEffort(ZIO.attempt(ps.executeQuery())).run + (conn, ps, rs) + } } val outStream: ZStream[Connection, Throwable, T] = diff --git a/quill-jdbc-zio/src/test/scala/io/getquill/examples/GenericDao.scala b/quill-jdbc-zio/src/test/scala/io/getquill/examples/GenericDao.scala index 00a8b914d..44fd69a7a 100644 --- a/quill-jdbc-zio/src/test/scala/io/getquill/examples/GenericDao.scala +++ b/quill-jdbc-zio/src/test/scala/io/getquill/examples/GenericDao.scala @@ -23,6 +23,9 @@ class Repo[T <: { def id: Int }](ds: DataSource) { inline def getById(inline id: Int) = run(query[T].filter(t => t.id == lift(id))).map(_.headOption).provideEnvironment(env) + inline def deleteById(inline id: Int) = + run(query[T].filter(t => t.id == lift(id)).delete).provideEnvironment(env) + inline def insert(inline t: T) = run(query[T].insertValue(lift(t)).returning(_.id)).provideEnvironment(env) @@ -46,6 +49,8 @@ object StructureBasedRepo extends ZIOAppDefault { joeId <- peopleRepo.insert(joe) joeNew <- peopleRepo.getById(joeId) allJoes <- peopleRepo.searchByField(p => p.first == "Joe") + _ <- peopleRepo.deleteById(joeId) + allJoes1 <- peopleRepo.searchByField(p => p.first == "Joe") _ <- printLine("==== joe: " + joe) *> printLine("==== joeId: " + joeId) *>